/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/functionalize_cond.h"

#include <algorithm>
#include <deque>
#include <stack>
#include <unordered_set>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/dump_graph.h"

using xla::StatusOr;

namespace tensorflow {
namespace functionalize_cond {

bool AncestorNode::operator<(const AncestorNode& other) const {
  return (output_tensor.node->id() < other.output_tensor.node->id()) ||
         (output_tensor.node->id() == other.output_tensor.node->id() &&
          output_tensor.index < other.output_tensor.index) ||
         (output_tensor.node->id() == other.output_tensor.node->id() &&
          output_tensor.index == other.output_tensor.index &&
          type < other.type);
}

bool AncestorNode::operator==(const AncestorNode& other) const {
  return output_tensor.node->id() == other.output_tensor.node->id() &&
         output_tensor.index == other.output_tensor.index && type == other.type;
}

size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
  size_t h = std::hash<int>()(ancestor.output_tensor.node->id());
  h = Hash64Combine(h, std::hash<int>()(ancestor.output_tensor.index));
  return Hash64Combine(h, std::hash<int>()(static_cast<int>(ancestor.type)));
}

typedef std::tuple<StateMap::CondId, StateMap::AncestorId, OutputTensor>
    ClusterTuple;

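// Orders ClusterTuples first by the interned CondState and AncestorState
// identifiers (compared as pointers, which suffices to group identical
// states, since states are interned in StateMap) and then by the predicate
// OutputTensor.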
struct ClusterTupleLessThan {
  bool operator()(const ClusterTuple& a, const ClusterTuple& b) const {
    if (std::tie(std::get<0>(a), std::get<1>(a)) <
        std::tie(std::get<0>(b), std::get<1>(b))) {
      return true;
    } else if (std::tie(std::get<0>(a), std::get<1>(a)) ==
               std::tie(std::get<0>(b), std::get<1>(b))) {
      return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b));
    } else {
      return false;
    }
  }
};

// TODO(jpienaar): Move to OutputTensor.
string DebugString(const OutputTensor& tensor) {
  return absl::StrCat(tensor.node->name(), ":", tensor.index);
}

string Branch_Name(BranchType b) {
  switch (b) {
    case BranchType::kElseBranch:
      return "else";
    case BranchType::kThenBranch:
      return "then";
    case BranchType::kBoth:
      return "both";
    case BranchType::kNeither:
      return "neither";
  }
}

string DebugString(StateMap::CondId cond_state) {
  if (cond_state == nullptr || cond_state->empty()) return "{}";
  using value_type = StateMap::CondState::value_type;
  return absl::StrCat(
      "{",
      absl::StrJoin(*cond_state, ", ",
                    [](string* output, const value_type& pred_branch) {
                      const OutputTensor& pred = pred_branch.first;
                      const BranchType& branch = pred_branch.second;
                      if (branch == BranchType::kNeither)
                        absl::StrAppend(output, "d");
                      else
                        absl::StrAppend(output, "s(", DebugString(pred), ",",
                                        Branch_Name(branch), ")");
                    }),
      "}");
}

// Returns the predicate of a switch.
Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
  const Edge* pred_edge;
  TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
  // The predicate can be preceded by Identity nodes. Look through them to
  // find the underlying predicate.
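  // E.g., for the chain pred:0 -> Identity -> Switch, the returned predicate
  // is pred:0 rather than the Identity's output.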
  while (pred_edge->src()->IsIdentity()) {
    TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
  }
  *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
  return Status::OK();
}

Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
  const Edge* val_edge;
  TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
  *val = OutputTensor(val_edge->src(), val_edge->src_output());
  return Status::OK();
}

bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
                                            const OutputTensor& rhs) const {
  return (lhs.node->id() < rhs.node->id()) ||
         (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
}

struct CondStateLess {
  bool operator()(const StateMap::CondState::value_type& lhs,
                  const StateMap::CondState::value_type& rhs) const {
    if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
      return true;
    if (lhs.first.node->id() == rhs.first.node->id() &&
        lhs.first.index == rhs.first.index)
      return lhs.second < rhs.second;
    return false;
  }
};

StateMap::StateMap(Graph* graph) {
  node_to_condid_map_.resize(graph->num_node_ids());
  node_to_ancestorid_map_.resize(graph->num_node_ids());
  // Initialize the dead state (empty state is designated with a nullptr).
  dead_id_ = GetCondId(
      {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
}

bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }

bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }

size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
  if (map.empty()) return 0;
  // Compute hash of the front element.
  auto it = map.begin();
  size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
                           hash<BranchType>()(it->second));
  for (++it; it != map.end(); ++it) {
    // Combine the hash with the remaining elements in the map.
    h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
                                       hash<BranchType>()(it->second)));
  }
  return h;
}

size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
  if (map.empty()) return 0;
  // Compute hash of the front element.
  auto it = map.begin();
  size_t h = AncestorNode::Hash()(*it);
  for (++it; it != map.end(); ++it) {
    // Combine the hash with the remaining elements in the map.
    h = Hash64Combine(h, AncestorNode::Hash()(*it));
  }
  return h;
}

// CondArgNode represents an input to the conditional and its corresponding
// switch nodes.
struct CondArgNode {
  explicit CondArgNode(Node* src, int src_output)
      : src(src), src_output(src_output) {}

  string ToString() const {
    return absl::StrCat("src=", src->name(), ":", src_output,
                        " switches=", NodesToString(switches));
  }

  Node* src;
  int src_output;
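  // Copy of the _Arg node created for this input in each branch body,
  // indexed by branch (see Conditional::BuildArgumentNodes).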
  std::array<Node*, 2> branch_copy;
  std::vector<Node*> switches;
};
using CondArgNodes = std::vector<CondArgNode>;

string DebugString(const CondArgNodes& nodes) {
  return absl::StrCat(
      "[",
      absl::StrJoin(nodes, ", ",
                    [](string* output, const CondArgNode& node) {
                      absl::StrAppend(output, node.ToString());
                    }),
      "]");
}

StateMap::CondId StateMap::LookupCondId(const Node* node) const {
  const int64 map_size = node_to_condid_map_.size();
  if (node->id() < map_size) return node_to_condid_map_[node->id()];
  return added_node_condid_mapping_.at(node->id());
}

StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
  if (state.empty()) return nullptr;
  return &*condstate_set_.insert(state).first;
}

void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
  const int64 map_size = node_to_condid_map_.size();
  if (node->id() < map_size)
    node_to_condid_map_[node->id()] = id;
  else
    added_node_condid_mapping_[node->id()] = id;
}

StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
  const int64 map_size = node_to_ancestorid_map_.size();
  if (node->id() < map_size) return node_to_ancestorid_map_[node->id()];
  return added_node_ancestorid_mapping_.at(node->id());
}

StateMap::AncestorId StateMap::GetAncestorId(
    const StateMap::AncestorState& state) {
  if (state.empty()) return nullptr;
  return &*ancestorstate_set_.insert(state).first;
}

void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
  const int64 map_size = node_to_ancestorid_map_.size();
  if (node->id() < map_size)
    node_to_ancestorid_map_[node->id()] = id;
  else
    added_node_ancestorid_mapping_[node->id()] = id;
}

void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }

string StateMap::CondStateToString(const Node* node) const {
  return CondStateToString(LookupCondId(node));
}

string StateMap::CondStateToString(StateMap::CondId id) const {
  return DebugString(id);
}

string StateMap::AncestorStateToString(const Node* node) const {
  if (auto id = LookupAncestorId(node)) {
    return absl::StrCat(
        "{",
        absl::StrJoin(*id, ",",
                      [](string* output, const AncestorNode& ancestor) {
                        absl::StrAppend(output,
                                        ancestor.output_tensor.node->name(),
                                        ":", ancestor.output_tensor.index);
                      }),
        "}");
  }
  return "{}";
}

FunctionalizeCond::FunctionalizeCond(Graph* graph,
                                     FunctionLibraryDefinition* library,
                                     const NodeFilter& node_filter)
    : state_map_(graph),
      library_(library),
      graph_(graph),
      node_filter_(node_filter) {}

// Class representing the merge/switch nodes that will become a conditional.
class Conditional {
 public:
  Conditional(OutputTensor predicate, FunctionalizeCond* parent,
              StateMap* cond_state_map, const ShapeRefiner& refiner);

  // Adds merge node that is part of this conditional.
  Status AddMerge(Node* m);

  // Constructs an If node from the merge nodes.
  Status BuildAndReplace(
      Graph* graph, FunctionLibraryDefinition* library,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

 private:
  // Extracts the then/else bodies: creates new graphs whose nodes correspond
  // to the nodes in the then/else branches of this conditional, to be used as
  // function bodies.
  Status ExtractBodies(Graph* graph);

  // Builds the arguments that are the input to the If.
  Status BuildArgumentNodes();

  // Builds the If node for the extracted bodies with the given predicate.
  Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);

  // Adds input edges to If node.
  Status AddInputEdges(
      Graph* graph,
      const std::unordered_map<Node*, OutputTensor>& merge_to_replacement);

  // Adds output edges from If node.
  // Record new output tensor for all Merge nodes in 'merge_to_replacement'.
  Status AddOutputEdges(
      Graph* graph,
      std::unordered_map<Node*, OutputTensor>* merge_to_replacement);

  // Adds switch node that is part of this conditional.
  Status AddSwitch(Node* s);

  // Adds a switch node along the edge and rewires the edge to go via the
  // switch.
  Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                Graph* graph);

  // Internal name of conditional. The name is based on the first merge node
  // added.
  string name() const;

  // The FunctionalizeCond instance that created this.
  FunctionalizeCond* parent_;

  // Mapping between nodes and their cond state.
  StateMap* state_map_;

  // The predicate of the conditional.
  OutputTensor predicate_;

  // Shape refiner of ops in the graph.
  const ShapeRefiner& refiner_;

  // The predicate of the switches of the conditional. This may differ from
  // predicate_ (which is initialized from the original graph), as the
  // predicate could be the output of a newly created If node.
  OutputTensor switch_predicate_;

  // Switch nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> switches_;

  // Merge nodes in graph that are part of this conditional.
  std::set<Node*, NodeCmpByNameResourcesLast> merges_;

  // Vector of control inputs from outside the conditional to a node inside.
  std::vector<Node*> external_control_inputs_;
  std::vector<Node*> external_control_outputs_;

  // Graphs corresponding to the then and else branch.
  std::array<std::unique_ptr<Graph>, 2> bodies_;

  // Maps node ids in graph_ to the corresponding nodes in each branch body's
  // graph.
  std::array<std::vector<Node*>, 2> node_maps_;

  // The argument nodes created for the switches.
  CondArgNodes cond_arg_nodes_;

  // The constructed If node.
  Node* if_node_ = nullptr;

  // Whether the merge nodes of this conditional have been replaced.
  bool replaced_ = false;
};

Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
                         StateMap* cond_state_map, const ShapeRefiner& refiner)
    : parent_(parent),
      state_map_(cond_state_map),
      predicate_(predicate),
      refiner_(refiner) {}

Status Conditional::AddMerge(Node* m) {
  merges_.insert(m);
  return Status::OK();
}

Status Conditional::AddSwitch(Node* s) {
  VLOG(5) << "Adding switch " << s->DebugString();
  OutputTensor predicate;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
  if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
  if (!(switch_predicate_ == predicate)) {
    return errors::InvalidArgument(
        "Merge nodes ", NodesToString(merges_),
        " directly dominated by switch nodes with different predicates (",
        DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
  }
  switches_.insert(s);
  parent_->AddSwitchId(s->id());
  return Status::OK();
}

Status Conditional::BuildArgumentNodes() {
  VLOG(1) << "Build function arguments";
  struct Hash {
    size_t operator()(const std::pair<Node*, int>& item) const {
      return Hash64Combine(hash<Node*>()(item.first),
                           std::hash<int>()(item.second));
    }
  };

  std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
  for (Node* switch_node : switches_) {
    const Edge* e;
    TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
    std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
    if (input_index.find(key) == input_index.end()) {
      input_index[key] = cond_arg_nodes_.size();
      cond_arg_nodes_.emplace_back(key.first, key.second);
    }
    cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
  }
  VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);

  int arg_count = 0;
  for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
    DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      TF_RETURN_IF_ERROR(
          NodeBuilder(absl::StrCat("_Arg", arg_count),
                      FunctionLibraryDefinition::kArgOp)
              .Attr("T", dtype)
              .Attr("index", arg_count)
              .Finalize(bodies_[branch_index].get(),
                        &cond_arg_node.branch_copy[branch_index]));
    }
    for (Node* node : cond_arg_node.switches) {
      for (const Edge* e : node->out_edges()) {
        if (e->IsControlEdge()) continue;
        int branch_index = e->src_output();
        Node* src_copy = cond_arg_node.branch_copy[branch_index];
        Node* dst_copy = node_maps_[branch_index][e->dst()->id()];

        // The graph may contain dead switch nodes: in that case no copy of
        // the destination node exists in this branch body, so skip the edge.
        if (dst_copy == nullptr) continue;

        // If the input goes directly to a merge then the merge has
        // been replaced by a retval so the dst input is 0 instead of
        // dst_input.
        int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
        bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
      }
    }
    ++arg_count;
  }

  // Verify that all retvals have an input.
  // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
  // input.
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      bool has_input = false;
      for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
        if (!e->IsControlEdge()) {
          has_input = true;
          break;
        }
      }
      if (!has_input) {
        return errors::Internal(
            "Failed to functionalize control flow with merge ",
            FormatNodeForError(*m), " that doesn't have input on ",
            Branch_Name(branch), " branch.");
      }
    }
  }

  return Status::OK();
}

Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
                                           Graph* graph) {
  // Previously we had the edge:
  //   src:src_output ---- edge ----> dst:dst_input
  // After this rewrite we have (in graph):
  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input

  // TODO(jpienaar): One could keep a map caching the extra switch nodes added
  // to avoid adding another switch to feed a value for which a switch was
  // already added.
  Node* switch_node;
  Node* src = edge->src();
  int src_output = edge->src_output();
  TF_RETURN_IF_ERROR(
      NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
                  "Switch")
          .Input(src, src_output)
          .Input(const_cast<Node*>(predicate_.node), predicate_.index)
          .Finalize(graph, &switch_node));
  state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
  state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));

  Node* dst = edge->dst();
  int dst_input = edge->dst_input();
  graph->RemoveEdge(edge);
  graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
  return AddSwitch(switch_node);
}

Status Conditional::ExtractBodies(Graph* graph) {
  VLOG(2) << "Extracting bodies for " << name();
  for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    bodies_[static_cast<int>(b)] =
        absl::make_unique<Graph>(graph->op_registry());
  }

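  // Returns the branch along which edge `e` flows into this conditional: for
  // an edge leaving a Switch this is given by the output port, otherwise it
  // is looked up in the source node's CondState.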
  auto find_branch = [&](const Edge* e) {
    const auto& id = state_map_->LookupCondId(e->src());
    return IsSwitch(e->src()) ? BranchType(e->src_output())
                              : state_map_->FindBranchOf(id, predicate_);
  };

  std::array<std::vector<Node*>, 2> stacks;
  VLOG(5) << "Merges: " << NodesToString(merges_);
  for (Node* m : merges_) {
    VLOG(5) << "For merge: " << m->DebugString() << " "
            << state_map_->CondStateToString(m);
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      BranchType branch = find_branch(e);
      TF_RET_CHECK(branch == BranchType::kThenBranch ||
                   branch == BranchType::kElseBranch)
          << "Error: " << e->src()->name()
          << " is not on either then or else branch (" << Branch_Name(branch)
          << ") for predicate " << DebugString(predicate_) << " ["
          << DebugString(state_map_->LookupCondId(e->src())) << "].";
      Node* src = e->src();
      if (IsSwitch(src)) {
        // Switch node outputs and dependencies are handled separately.
        TF_RETURN_IF_ERROR(AddSwitch(src));
      } else {
        stacks[static_cast<int>(branch)].push_back(src);
      }
    }
  }

  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);
    auto output = bodies_[branch_index].get();
    auto& stack = stacks[branch_index];
    VLOG(5) << "In branch: " << Branch_Name(branch) << " "
            << NodesToString(stack);
    std::vector<bool> visited(graph->num_node_ids(), false);
    node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
    auto& node_map = node_maps_[branch_index];

    while (!stack.empty()) {
      Node* n = stack.back();
      stack.pop_back();

      if (visited.at(n->id())) continue;
      visited[n->id()] = true;

      // Verify output edges and record control edges exiting scope.
      for (const Edge* e : n->out_edges()) {
        Node* dst = e->dst();
        if (IsMerge(dst)) continue;
        Node* src = e->src();

        auto dst_id = state_map_->LookupCondId(dst);
        auto src_id = state_map_->LookupCondId(src);
        if (dst_id != src_id) {
          if (e->IsControlEdge()) {
            external_control_outputs_.push_back(e->src());
          } else {
            // Constants are treated specially to work around the case of
            // non-dominated constant nodes.
            if (!IsConstant(src)) {
              // TODO(b/78882471): A node that feeds into two different
              // CondStates is not necessarily an error, so log a warning for
              // now; revisit to improve the testing so this can be made an
              // error.
              LOG(WARNING) << errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  ") (detected during out edge testing).");
            }
          }
        }
      }

      // Copying incoming edges to dst node. Iterate over a copy of the edges
      // as they could be mutated during iteration.
      std::vector<const Edge*> in_edges(n->in_edges().begin(),
                                        n->in_edges().end());
      // Sort in_edges to make sure nodes are copied in a deterministic order.
      std::sort(
          in_edges.begin(), in_edges.end(), [](const Edge* a, const Edge* b) {
            int a_src_output = a->src_output(), b_src_output = b->src_output();
            StringPiece a_name(a->src()->name()), b_name(b->src()->name());
            return std::tie(a_src_output, a_name) <
                   std::tie(b_src_output, b_name);
          });
      for (const Edge* e : in_edges) {
        Node* src = e->src();
        // Skip source/sink nodes (which are not ops).
        if (!src->IsOp()) continue;

        Node* dst = e->dst();
        if (IsSwitch(src)) {
          // Switch node outputs and dependencies are handled separately.
          TF_RETURN_IF_ERROR(AddSwitch(src));
          continue;
        }

        // Verify input is from the same context.
        auto src_id = state_map_->LookupCondId(src);
        auto dst_id = state_map_->LookupCondId(dst);
        if (IsMerge(dst) || src_id == dst_id) {
          // TODO(jpienaar): The merge case can be more strict.
          if (node_map.at(src->id()) == nullptr) {
            node_map.at(src->id()) = output->CopyNode(src);
            stack.push_back(src);
          }
        } else if (e->IsControlEdge()) {
          // Here we have a control flow edge between src and dst that are not
          // in the same context. This is an external control dependency except
          // for one case: where the only difference between CondId of e->src()
          // and CondId of e->dst() is that e->src() has {PRED, kNeither} and
          // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in
          // gradients code for tf.cond(), where e->src() is a control pivot
          // node for a branch and e->dst() is a data node in that branch.
          bool is_external_control_input = true;
          if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) {
            std::vector<StateMap::CondState::value_type> diff;
            std::set_symmetric_difference(
                src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(),
                std::back_inserter(diff), CondStateLess());
            if (diff.size() == 2 && diff[0].first == diff[1].first &&
                (diff[0].second == BranchType::kNeither ||
                 diff[1].second == BranchType::kNeither)) {
              auto src_branch = src_id->find(diff[0].first);
              if (src_branch != src_id->end() &&
                  src_branch->second == BranchType::kNeither) {
                is_external_control_input = false;
              }
            }
          }
          if (is_external_control_input) {
            external_control_inputs_.push_back(src);
          }
        } else {
          // This shouldn't happen: it means we have an external data input
          // that does not enter via a switch node. Work around this as
          // follows:
          //  * for constant nodes: copy them;
          //  * for non-constant nodes: insert a switch along the edge.
          if (IsConstant(src)) {
            // Check if constant node was added already. It is possible to have
            // multiple uses of a constant node.
            if (node_map.at(src->id()) == nullptr) {
              node_map.at(src->id()) = output->CopyNode(src);
            }
          } else {
            StateMap::CondState state = *dst_id;
            state.erase(predicate_);
            if (state_map_->GetCondId(state) == src_id) {
              TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
              continue;
            } else {
              return errors::InvalidArgument(
                  "Graph contains node ", FormatNodeForError(*src),
                  " that feeds into node ", FormatNodeForError(*dst),
                  " but these nodes are in different control contexts (",
                  DebugString(src_id), " vs ", DebugString(dst_id),
                  ") (detected during in edge testing).");
            }
          }
        }

        Node* src_copy = node_map.at(e->src()->id());
        int src_output = e->src_output();
        if (node_map.at(dst->id()) == nullptr) {
          node_map.at(dst->id()) = output->CopyNode(dst);
        }
        Node* dst_copy = node_map.at(e->dst()->id());
        if (e->IsControlEdge()) {
          // Skip control inputs from external context.
          if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
        } else {
          output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
        }
      }
    }
  }

  // Build return values from the merge nodes.
  int index = 0;
  for (Node* m : merges_) {
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      TF_ASSIGN_OR_RETURN(node_map[m->id()],
                          BuildRetvalNode(output, m->output_type(0), index));
    }
    ++index;

    // Connect the inputs of the merge to the retval, except if the input is a
    // Switch node, which is handled separately.
    for (auto e : m->in_edges()) {
      if (e->IsControlEdge()) continue;
      int branch_index = static_cast<int>(find_branch(e));
      auto& node_map = node_maps_[branch_index];
      auto output = bodies_[branch_index].get();
      Node* in = e->src();
      if (!IsSwitch(in)) {
        if (node_map.at(in->id()) == nullptr) {
          node_map[in->id()] = output->CopyNode(in);
        }
        output->AddEdge(node_map[in->id()], e->src_output(),
                        node_map.at(m->id()), 0);
      }
    }
  }
  return Status::OK();
}

Status Conditional::BuildIfNode(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "Build cond function for " << name();
  NodeDebugInfo debug_info((*merges_.begin())->def());
  NodeDefBuilder builder(name(), "If", library, &debug_info);
  const string branch_name[] = {"else_branch", "then_branch"};
  for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
    int branch_index = static_cast<int>(branch);

    NameAttrList body_name;
    body_name.set_name(library->UniqueFunctionName(
        absl::StrCat("_functionalize_if_", branch_name[branch_index], "_")));

    VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
            << "): "
            << DumpGraphToFile(
                   "functionalize_cond_body_" + branch_name[branch_index],
                   *bodies_[branch_index], nullptr);

    FunctionDef body_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
                                          body_name.name(), &body_fdef));
    TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
    builder.Attr(branch_name[branch_index], body_name);
  }

  VLOG(3) << "Build input type";
  std::vector<NodeDefBuilder::NodeOut> inputs;
  DataTypeVector in_arg_types;
  for (auto& kv : cond_arg_nodes_) {
    bool inserted = false;
    for (const Node* arg : kv.switches) {
      const Edge* in_edge;
      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
      if (in_edge->IsControlEdge()) {
        builder.ControlInput(in_edge->src()->name());
      } else {
        if (!inserted) {
          DataType dtype = arg->input_type(0);
          inputs.emplace_back(NodeDefBuilder::NodeOut(
              in_edge->src()->name(), in_edge->src_output(), dtype));
          in_arg_types.push_back(dtype);
          inserted = true;
        }
      }
    }
  }
  builder.Attr("Tin", in_arg_types);

  DataTypeVector out_type;
  std::vector<PartialTensorShape> output_shapes;
  output_shapes.reserve(merges_.size());
  for (const Node* merge : merges_) {
    DataType dtype = merge->output_type(0);
    TensorShapeProto shape;
    if (auto* shape_ctx = refiner_.GetContext(merge)) {
      shape_inference::ShapeHandle handle;
      shape_ctx->ShapeHandleToProto(shape_ctx->output(0), &shape);
    }
    out_type.push_back(dtype);
    output_shapes.push_back(shape);
  }
  builder.Attr("Tout", out_type);
  VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);
  builder.Attr("output_shapes", output_shapes);
  VLOG(3) << "Build output shapes: "
          << PartialTensorShapeUtils::PartialShapeListString(output_shapes);

  builder.Attr("Tcond", DT_BOOL);
  // Add some internal attributes which need to be propagated.
  // TODO(b/160275126): attributes shouldn't be hard-coded here
  for (const char* attr_name :
       {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
        kTpuReplicateAttrName}) {
    string attr_val;
    if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) {
      builder.Attr(attr_name, attr_val);
    }
  }
  builder.Device(predicate_.node->assigned_device_name());
  // The predicate (condition) must be the first input ...
  builder.Input(
      NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index,
                              predicate_.node->output_type(predicate_.index)));
  // ... followed by the other inputs.
  builder.Input(inputs);

  VLOG(3) << "Build If node";
  NodeDef if_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
  TF_ASSIGN_OR_RETURN(if_node_,
                      parent_->AddIfNode(if_def, *merges_.begin(), predicate_));

  return Status::OK();
}

Status Conditional::AddInputEdges(
    Graph* graph,
    const std::unordered_map<Node*, OutputTensor>& merge_to_replacement) {
  VLOG(2) << "AddInputEdges for " << if_node_->name();
  int index = 0;
  // Add predicate input.
  if (predicate_.node->IsMerge()) {
    // If the predicate is a Merge node, we should not use Merge output as
    // predicate. Instead, we should use the corresponding If output in
    // 'merge_to_replacement'. Otherwise, this Conditional's If node is still
    // connected to the predicate Merge node; and when we call
    // DeleteReachableAndDeadNodes(), the predicate Merge node and this
    // Conditional's If node will be removed.
    auto iter = merge_to_replacement.find(predicate_.node);
    if (iter == merge_to_replacement.end()) {
      return errors::Internal("Cannot find replacement for Merge node ",
                              predicate_.node->name());
    }
    graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++);
  } else {
    graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index,
                   if_node_, index++);
  }
  // Add function body inputs.
  for (auto& arg : cond_arg_nodes_) {
    if (arg.src_output == Graph::kControlSlot) {
      graph->AddControlEdge(arg.src, if_node_);
    } else {
      graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
    }
  }
  for (Node* n : external_control_inputs_) {
    graph->AddControlEdge(n, if_node_);
  }
  return Status::OK();
}

Status Conditional::AddOutputEdges(
    Graph* graph,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(2) << "AddOutputEdges for " << if_node_->name();
  int i = 0;
  for (Node* node : merges_) {
    TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
    std::vector<const Edge*> edges(node->out_edges().begin(),
                                   node->out_edges().end());
    for (const Edge* edge : edges) {
      Node* dst = edge->dst();
      int dst_input = edge->dst_input();
      if (edge->src_output() > 0) {
        return errors::Unimplemented("Output of index (", edge->src_output(),
                                     ") of merge node ",
                                     FormatNodeForError(*node));
      }

      bool control_edge = edge->IsControlEdge();
      graph->RemoveEdge(edge);
      if (control_edge) {
        graph->AddControlEdge(if_node_, dst);
      } else {
        graph->AddEdge(if_node_, i, dst, dst_input);
      }
    }

    // Record corresponding output tensor in 'merge_to_replacement'.
    (*merge_to_replacement)[node] = OutputTensor{if_node_, i};

    ++i;
  }
  for (Node* n : external_control_outputs_) {
    graph->AddControlEdge(if_node_, n);
  }

  return Status::OK();
}

Status Conditional::BuildAndReplace(
    Graph* graph, FunctionLibraryDefinition* library,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(1) << "Build If and replace merge nodes "
          << NodesToString(this->merges_);
  if (replaced_) return Status::OK();

  TF_RETURN_IF_ERROR(ExtractBodies(graph));
  TF_RETURN_IF_ERROR(BuildArgumentNodes());

  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Extracted bodies:";
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto output = bodies_[branch_index].get();
      LOG(INFO) << Branch_Name(branch) << ": "
                << DebugString(output->ToGraphDefDebug());
    }
  }

  TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
  TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement));
  TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement));
  TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));

  // Check that the if_node doesn't feed into itself.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
      "Converting to If failed.");

  replaced_ = true;
  return Status::OK();
}

string Conditional::name() const {
  CHECK(!merges_.empty());
  return absl::StrCat((*merges_.begin())->name(), "_if");
}

Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
                                          int port) {
  NodeBuilder id_builder(replacee->name(), "Identity");
  id_builder.Input(if_node, port);
  string outside_compilation;
  if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName,
                  &outside_compilation)
          .ok()) {
    id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
  }
  Node* id;
  TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
  state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
  state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
  return Status::OK();
}

StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
                                             const Node* replacee,
                                             const OutputTensor& predicate) {
  Status status;
  Node* ret = graph_->AddNode(def, &status);
  TF_RETURN_IF_ERROR(status);
  VLOG(1) << "Adding If for " << replacee->name();
  StateMap::CondId id = state_map_.LookupCondId(replacee);
  if (id) {
    StateMap::CondState state = *id;
    state.erase(predicate);
    state_map_.ResetCondId(ret, state_map_.GetCondId(state));
  } else {
    state_map_.ResetCondId(ret, nullptr);
  }

  state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));

  return ret;
}

Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
999 VLOG(2) << "Propagating update state for " << replacee->name() << " "
1000 << state_map_.CondStateToString(replacee);
  // Redo topological sort as the order could have changed.
  // TODO(jpienaar): The original topological order could also be updated
  // dynamically if needed.
  std::vector<Node*> rev_topo_order;
  GetPostOrder(*graph_, &rev_topo_order);

  // All the outputs of the new node could potentially be updated.
  std::unordered_set<Node*> changed;
  for (auto n : replacee->out_nodes())
    if (n->IsOp()) changed.insert(n);

  // Iterate through the changed/possibly-changed nodes in topological order.
  for (auto it = rev_topo_order.rbegin();
       it != rev_topo_order.rend() && !changed.empty(); ++it) {
    if (changed.find(*it) != changed.end()) {
      // Update the node state.
      Node* n = *it;
      StateMap::CondId old_state = state_map_.LookupCondId(n);
      state_map_.ResetCondId(n, nullptr);
      TF_RETURN_IF_ERROR(DetermineCondState(n));
      if (state_map_.LookupCondId(n) != old_state) {
        for (auto out : n->out_nodes())
          if (out->IsOp()) changed.insert(out);
      }
      changed.erase(n);
    }
  }
  return Status::OK();
}

// Returns the most restrictive branch of two branches or neither. This is the
// meet operator of the BranchType lattice.
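// E.g., kBoth and kNeither both act as identity elements:
//   MeetBranch(kThenBranch, kBoth) == kThenBranch,
//   MeetBranch(kThenBranch, kNeither) == kThenBranch,
// while the meet of the two concrete branches is
//   MeetBranch(kThenBranch, kElseBranch) == kNeither.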
BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
  if (lhs == rhs) return lhs;
  if (lhs == BranchType::kNeither) return rhs;
  if (rhs == BranchType::kNeither) return lhs;
  if (lhs == BranchType::kBoth) return rhs;
  if (rhs == BranchType::kBoth) return lhs;
  return BranchType::kNeither;
}

BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
  if (IsEmpty(id)) return BranchType::kNeither;
  const CondState& nodes = *id;
  auto it = nodes.find(predicate);
  if (it == nodes.end()) return BranchType::kNeither;
  return it->second;
}

StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
    StateMap::CondId src, StateMap::CondId dst) {
  VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
          << "] and dst=" << DebugString(dst) << " [" << dst << "]";

  if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;

  // Nothing to do if the CondState is the same.
  if (src == dst) return src;

  StateMap::CondState both = *src;
  for (const auto& kv : *dst) {
    auto it = both.find(kv.first);
    if (it == both.end()) {
      both.insert(kv);
    } else {
      if (it->second != kv.second) {
        if (it->second == BranchType::kNeither) {
          // BranchType for 'src' is kNeither. Use the BranchType in 'dst'.
          it->second = kv.second;
        } else if (kv.second == BranchType::kNeither) {
          // BranchType for 'dst' is kNeither. Use the BranchType in 'src'.
          // No need to change it->second.
        } else {
          return errors::InvalidArgument(
              "Graph contains node with inputs predicated on incompatible "
              "predicates: ",
              DebugString(src), " and ", DebugString(dst));
        }
      }
    }
  }
  return state_map_.GetCondId(both);
}

StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
  // Determine the flow state when joining two states for a merge node.
  // Combining the two states for a merge node is effectively performing a
  // disjunction of the states along the different input edges. For a merge
  // that can be transformed into an If, the two input paths have to have a
  // predicate on which they differ (e.g., along one edge predicate `p` has to
  // hold while on the other it must not). This function first determines this
  // predicate, and then the resultant state is the common path between the
  // two inputs followed by s(p, both).
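  // E.g., joining {s(p, then)} and {s(p, else)} at a merge yields the empty
  // common state, with p recorded as the merge's predicate in
  // merge_to_predicate_.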
1096 VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
1097 << DebugString(dst);
1098 if (state_map_.IsEmpty(dst)) return src;
1099 if (state_map_.IsEmpty(src)) {
1100 return errors::Internal("Merge node ", merge->name(),
1101 " has input that's not in any CondContext.");
1102 }
1103
1104 if (state_map_.IsDead(src)) return src;
1105 if (state_map_.IsDead(dst)) return dst;
1106
1107 std::vector<StateMap::CondState::value_type> diff;
1108 StateMap::CondState merged;
1109 std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
1110 dst->end(), std::back_inserter(diff),
1111 CondStateLess());
1112 std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
1113 std::inserter(merged, merged.begin()), CondStateLess());
1114
1115 // Update mapping from merge node to predicate.
1116 if (diff.size() == 2) {
1117 auto pred = diff[0].first;
1118 bool different_branches = (diff[0].second != diff[1].second) &&
1119 (diff[0].second == BranchType::kThenBranch ||
1120 diff[0].second == BranchType::kElseBranch) &&
1121 (diff[1].second == BranchType::kThenBranch ||
1122 diff[1].second == BranchType::kElseBranch);
1123 if (!(pred == diff[1].first) || !different_branches)
1124 return errors::InvalidArgument(
1125 "Unable to determine predicate for merge node");
1126 merge_to_predicate_[merge] = pred;
1127 } else {
1128 return errors::InvalidArgument(
1129 "Merge of two inputs that differ on more than one predicate ",
1130 DebugString(src), " and ", DebugString(dst));
1131 }
1132
1133 return state_map_.GetCondId(merged);
1134 }
1135
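// Returns the CondState for the destination of edge `e`: the source node's
// state, extended with the switch predicate and branch if the source is a
// Switch node.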
StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
  Node* src = e->src();
  StateMap::CondId id = state_map_.LookupCondId(e->src());

  // Dead nodes only propagate dead state.
  if (state_map_.IsDead(id)) return id;

  if (IsSwitch(src)) {
    StateMap::CondState state;
    if (id != nullptr) state = *id;
    OutputTensor predicate;
    TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
    if (e->IsControlEdge()) {
      // In gradients of tf.cond(), each branch has a NoOp node as control
      // pivot. These NoOp nodes have a control dependency on the Switch node.
      // If we don't record this in the CondState, branches might have an
      // incorrect CondState (e.g. if the branch only has a Const data node).
      // We set it to kNeither because there is no way to tell whether it's
      // for the true branch or the false branch. This node's descendants
      // might have other incoming edges with a defined BranchType, and we
      // correctly handle merging kNeither with other defined BranchTypes in
      // StateAlongEdge().
      state[predicate] = BranchType::kNeither;
    } else {
      state[predicate] = BranchType(e->src_output());
    }
    return state_map_.GetCondId(state);
  }
  return id;
}

Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
  // Only Merge nodes with two inputs are supported, but if this is a redundant
  // merge, then the dead edge may already have been removed (if due to a
  // switch) and so the input count would be incorrect.
  if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();

  int data_inputs = 0;
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
            << state_map_.CondStateToString(src);
    if (!src->IsOp()) continue;
    if (!e->IsControlEdge()) ++data_inputs;

    StateMap::CondId prop = StateAlongEdge(e);
    auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }

  // Incomplete Merge nodes are not supported.
  if (data_inputs != 2) {
    return errors::Unimplemented(
        dst->name(), " only has ", data_inputs,
        " inputs, while only merge nodes with two inputs are supported.");
  }
  return Status::OK();
}

Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
  // Handle non-merge join.
  for (auto e : dst->in_edges()) {
    VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
            << state_map_.CondStateToString(dst);
    Node* src = e->src();
    if (!src->IsOp()) continue;

    // Joining the state between the current and propagated state.
    StateMap::CondId prop = StateAlongEdge(e);
    auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                    FormatNodeForError(*dst));
    state_map_.ResetCondId(dst, id_or.ValueOrDie());
  }
  return Status::OK();
}

Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
  // Handle redundant merge nodes. A merge node is considered redundant if
  // one input edge is dead while the other has a value.
  if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();

  const Edge* non_dead_edge = nullptr;
  for (auto e : node->in_edges()) {
    if (e->IsControlEdge()) continue;
    Node* src = e->src();

    // Handle merge with dead state.
    const auto& src_id = state_map_.LookupCondId(src);
    if (!state_map_.IsDead(src_id)) {
      non_dead_edge = e;
      break;
    }
  }

  if (non_dead_edge == nullptr) {
    return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
                                   " has no non-dead inputs.");
  }
  state_map_.MarkDead(node);
  VLOG(5) << "removing redundant merge: " << node->name();
  while (!node->out_edges().empty()) {
    const Edge* oe = *node->out_edges().begin();
    Node* dst_node = oe->dst();
    int dst_port = oe->dst_input();
    graph_->RemoveEdge(oe);
    graph_->AddEdge(non_dead_edge->src(),
                    dst_port == Graph::kControlSlot
                        ? Graph::kControlSlot
                        : non_dead_edge->src_output(),
                    dst_node, dst_port);
  }
  return Status::OK();
}

Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
  // Handle redundant switch nodes. A switch node is considered redundant if
  // the predicate of the switch already holds on the current branch. E.g., if
  // p is the predicate of the switch but p is already known to hold on this
  // branch, then the switch can be removed and the dead state propagated
  // along one branch. The predicate check is based on exact predicate
  // identity (rather than boolean equivalence) and is aimed at the redundant
  // switches currently generated by gradient code.
  StateMap::CondId dst_id = state_map_.LookupCondId(node);
  if (state_map_.IsDead(dst_id)) return Status::OK();

  BranchType b;
  OutputTensor pred;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));

  // Determine if we are already on a branch where the switch predicate is
  // true/false. Consider both the data and predicate to determine if the
  // node is redundant (skipping over identity node).
  b = state_map_.FindBranchOf(dst_id, pred);
  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
    OutputTensor val;
    const Edge* e;
    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
    val = OutputTensor(e->src(), e->src_output());
    while (IsIdentity(val.node)) {
      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
      val = OutputTensor(e->src(), e->src_output());
    }
    b = state_map_.FindBranchOf(dst_id, val);
    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
      return Status::OK();
  }

  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b)
          << " " << DebugString(dst_id);
  const Edge* value_edge;
  TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
  Node* val_node = value_edge->src();
  int val_port = value_edge->src_output();
  while (!node->out_edges().empty()) {
    auto e = *node->out_edges().begin();
    Node* dst_node = e->dst();
    int dst_input = e->dst_input();
    int switch_branch = e->src_output();
    graph_->RemoveEdge(e);
    if (switch_branch == Graph::kControlSlot) {
      if (IsMerge(dst_node)) {
        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
                                         state_map_.LookupCondId(dst_node));
        TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                        FormatNodeForError(*dst_node));
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      } else {
        auto id_or =
            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
        TF_RETURN_IF_ERROR(id_or.status());
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      }
    } else if (BranchType(switch_branch) != b) {
      state_map_.MarkDead(dst_node);
      continue;
    }
    graph_->AddEdge(
        val_node,
        switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
        dst_node, dst_input);
  }
  return Status::OK();
}

Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
  // The state that is propagated along the given edge.
  for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
    Node* dst = *it;
    TF_RETURN_IF_ERROR(DetermineCondState(dst));
    TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
    if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
    if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));

    VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
            << " @ " << state_map_.AncestorStateToString(dst);
    if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
  }
  return Status::OK();
}

Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
  StateMap::AncestorId id = nullptr;
  StateMap::AncestorState state;

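  // Accumulates into `state` the ancestor state of `src`, plus `src` itself
  // if it is a Merge or Switch node, and returns the interned AncestorId.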
  auto insert = [&](StateMap::AncestorId id, Node* src) {
    auto other_id = state_map_.LookupAncestorId(src);
    if (other_id != id && other_id != nullptr) {
      state.insert(other_id->begin(), other_id->end());
    }
    if (IsMerge(src)) {
      state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
    } else if (IsSwitch(src)) {
      OutputTensor pred;
      // For dead switch nodes, GetSwitchPredicate() will fail, and we use
      // the switch node directly as ancestor.
      if (GetSwitchPredicate(*src, &pred).ok()) {
        state.insert({pred, AncestorNode::AncestorNodeType::kPred});
      } else {
        state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
      }
    }
    return state_map_.GetAncestorId(state);
  };

  // Compute the union of all the switch/merge nodes that affect the input of
  // dst.
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    id = insert(id, src);
  }
  state_map_.ResetAncestorId(dst, id);
  return Status::OK();
}

void FunctionalizeCond::DeleteReachableAndDeadNodes(
    const std::vector<Node*>& merge_order) {
  // Delete all nodes that have been extracted or are reachable from
  // deleted/dead nodes. The input and outgoing edges should have already been
  // removed.
  std::deque<int> delete_nodes;
  std::vector<bool> deleted(graph_->num_node_ids(), false);
  // Don't try to delete source or sink nodes.
  deleted[graph_->kSourceId] = true;
  deleted[graph_->kSinkId] = true;

  // All remaining switch nodes that were not excluded from functionalization
  // according to `node_filter_` are unreachable from a merge node and are
  // removed here. This is to account for dead switch nodes.
  for (int s_id : switch_ids_) {
    Node* s = graph_->FindNodeId(s_id);
    if (s == nullptr) continue;
    for (const Edge* e : s->out_edges()) {
      // Control outputs of switch nodes (which are unconditionally executed if
      // the switch is) are not removed as they need not be part of a
      // conditional.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    // Only remove switch node if we have functionalized the corresponding
    // condition before (according to `node_filter_`).
    if (!node_filter_ || node_filter_(s)) {
      VLOG(2) << "Removing obsolete switch node " << s->name();
      deleted[s_id] = true;
      graph_->RemoveNode(s);
    }
  }

  // All merge nodes that were not excluded from functionalization according to
  // `node_filter_` should have been transformed at this point and we remove
  // them from the graph here.
  for (Node* m : merge_order) {
    for (const Edge* e : m->out_edges()) {
      // As with the control outputs of switch nodes, don't remove the control
      // outputs of merge nodes.
1411 // TODO(jpienaar): Check cases where output edges still exist here vs
1412 // being removed in AddOutputEdges.
1413 if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
1414 }
1415 // Only remove merge node if we have functionalized the corresponding
1416 // condition before (according to `node_filter_`).
1417 if (!node_filter_ || node_filter_(m)) {
1418 VLOG(2) << "Removing obsolete merge node " << m->name();
1419 deleted[m->id()] = true;
1420 graph_->RemoveNode(m);
1421 }
1422 }
1423
1424 // Enqueue all the dead nodes.
1425 for (Node* n : graph_->nodes()) {
1426 if (state_map_.IsDead(state_map_.LookupCondId(n))) {
1427 delete_nodes.push_back(n->id());
1428 }
1429 }
1430 // Remove dead nodes and nodes that are reachable from dead nodes.
  while (!delete_nodes.empty()) {
    int d_id = delete_nodes.front();
    delete_nodes.pop_front();
    if (deleted[d_id]) continue;
    Node* d = graph_->FindNodeId(d_id);
    // Switch and Merge nodes could have been deleted already.
    if (d == nullptr) continue;
    for (const Edge* e : d->out_edges()) {
      delete_nodes.push_back(e->dst()->id());
    }
    VLOG(2) << "Removing obsolete node " << d->name();
    deleted[d_id] = true;
    graph_->RemoveNode(d);
  }
}

void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
  // Sort merge nodes by nesting depth.
  using sort_pair = std::pair<int, Node*>;
  std::vector<sort_pair> inner_to_outer_merge_order;
  inner_to_outer_merge_order.reserve(merge_order->size());
  for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
    Node* merge = *it;
    StateMap::CondId id = state_map_.LookupCondId(merge);
    int depth = id != nullptr ? id->size() : 0;
    inner_to_outer_merge_order.emplace_back(depth, merge);
  }
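  // Note: iterating `merge_order` in reverse turns its reverse topological
  // order (as recorded by the DFS post-order) into topological order, and the
  // stable sort below preserves that order among merges of equal depth.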
  std::stable_sort(
      inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
      [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
  merge_order->clear();
  for (sort_pair t : inner_to_outer_merge_order) {
    merge_order->push_back(t.second);
  }
}

Status FunctionalizeCond::FunctionalizeInternal() {
  // The general approach for converting a tf.cond (as lowered via switch/merge
  // nodes) to a functional if is as follows:
  // 1. Determine the topological order and collect all the switch and merge
  //    nodes in the graph;
  // 2. Compute the predicates and dominance structure for all the nodes in
  //    the graph - this includes which predicate must be true for an op to
  //    execute (predicate values are considered directly rather than
  //    attempting to determine deeper equivalence). We shall refer to this
  //    structure as the CondState;
  // 3. Sort the merge nodes by nesting depth;
  // 4. Extract merge nodes that share the same CondState and AncestorState,
  //    from the innermost to the outermost, into IfOps;
  // Note: In the above only nodes that feed into a merge node will be
  // considered for functionalization.
  // Note: Nodes for which `node_filter_` returns false are excluded.
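  //
  // For example (an illustrative sketch): the Python program
  //   tf.cond(pred, lambda: x + 1.0, lambda: x - 1.0)
  // is lowered to a Switch on `x` driven by `pred`, where the true output
  // feeds x + 1.0, the false output feeds x - 1.0, and a Merge joins the two
  // branch results. This pass rewrites such a subgraph into a single If node
  // whose then/else branches are extracted into functions in the library.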

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be
  //   no cycles at this point, so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Record the merge and switch nodes in reverse topological order;
  std::vector<Node*> rev_topo_order;
  std::vector<Node*> merge_order;
  DFS(*graph_, nullptr, [&](Node* n) {
    // Only collect switch and merge nodes that are not filtered out; those
    // form the conditions that will be functionalized.
    if (!node_filter_ || node_filter_(n)) {
      if (IsSwitch(n)) {
        AddSwitchId(n->id());
      }
      if (IsMerge(n)) {
        merge_order.push_back(n);
      }
    }
    // Collect all other nodes here, independent of `node_filter_`, because
    // they might belong to a condition that should be functionalized.
    if (n->IsOp()) {
      rev_topo_order.push_back(n);
    }
  });

  // No merges to functionalize.
  if (merge_order.empty()) {
    // No merges means no switch values are consumed (as we only consider
    // values fetchable as the output of a merge).
    DeleteReachableAndDeadNodes(merge_order);
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");

  // Determine the shapes of the ops in the graph.
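  // The refiner is handed to each Conditional constructed below; a shape that
  // cannot be inferred here only produces a warning, not a failure.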
  ShapeRefiner shape_refiner{graph_->versions().producer(),
                             graph_->op_registry()};
  std::vector<Node*> nodes;
  GetReversePostOrder(*graph_, &nodes);
  for (auto node : nodes) {
    if (!shape_refiner.AddNode(node).ok()) {
      LOG(WARNING) << "Couldn't deduce shape for " << node->name();
    }
  }

  // Sort the merge nodes from innermost outwards.
  SortMergeNodes(&merge_order);

  // Cluster merge nodes by (CondId, AncestorId, predicate) in order of
  // nesting. (CondId, AncestorId) is not enough, e.g.
  //   pred1 = array_ops.placeholder(dtypes.bool, name='pred1')
  //   pred2 = array_ops.placeholder(dtypes.bool, name='pred2')
  //   cond1 = control_flow_ops.cond(pred1, ...)
  //   cond2 = control_flow_ops.cond(pred2, ...)
  //   cond3 = control_flow_ops.cond(pred1, use cond1 and cond2)
  //   cond4 = control_flow_ops.cond(pred2, use cond1 and cond2)
  // cond3 and cond4 have the same (CondId, AncestorId), but they should not
  // be merged into one "If" node (because they have different predicates).
  std::deque<std::vector<Node*>> merge_clusters;
  std::map<ClusterTuple, int, ClusterTupleLessThan> merge_cluster_index;
  for (Node* merge : merge_order) {
    auto cond_id = state_map_.LookupCondId(merge);
    if (state_map_.IsDead(cond_id)) continue;

    auto predicate = merge_to_predicate_.find(merge);
    if (predicate == merge_to_predicate_.end()) {
      return errors::Internal("Cannot find predicate for Merge node ",
                              merge->name());
    }

    ClusterTuple key = std::make_tuple(
        cond_id, state_map_.LookupAncestorId(merge), predicate->second);
    auto idx = merge_cluster_index.find(key);
    if (idx == merge_cluster_index.end()) {
      merge_cluster_index[key] = merge_clusters.size();
      merge_clusters.push_back({merge});
    } else {
      merge_clusters[idx->second].emplace_back(merge);
    }
  }

  // Extract the conditionals from innermost to outermost. Extracting from
  // innermost to outermost enables the extraction pass to stop once it
  // encounters a Switch node instead of having to keep track of the
  // Switch/Merge nodes seen.
  for (const auto& cluster : merge_clusters) {
    // Construct a Conditional with the predicate of the merge.
    Conditional cond(merge_to_predicate_.at(cluster.front()), this, &state_map_,
                     shape_refiner);
    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
    TF_RETURN_IF_ERROR(
        cond.BuildAndReplace(graph_, library_, &merge_to_replacement_));

    if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
  }

  DeleteReachableAndDeadNodes(merge_order);

  return Status::OK();
}

void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
  const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";

  for (Node* n : graph_->nodes()) {
    n->ClearAttr(kCondGroupDebugAttr);
    n->AddAttr(kCondGroupDebugAttr,
               absl::StrCat(state_map_.CondStateToString(n), "_",
                            state_map_.AncestorStateToString(n)));
  }
  LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
            << DumpGraphToFile(absl::StrCat("functionalize_cond_", name),
                               *graph_, library_);
}

void FunctionalizeCond::AddSwitchId(int switch_id) {
  switch_ids_.push_back(switch_id);
}

Status FunctionalizeCond::Functionalize(Graph* graph,
                                        FunctionLibraryDefinition* library,
                                        const NodeFilter& node_filter) {
  VLOG(1) << "FunctionalizeCond::Functionalize";
  FunctionalizeCond fc(graph, library, node_filter);
  return fc.FunctionalizeInternal();
}

}  // namespace functionalize_cond

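// Example usage (an illustrative sketch; the caller is assumed to have
// populated `graph` with Switch/Merge-based control flow):
//   Graph graph(OpRegistry::Global());
//   FunctionLibraryDefinition library(OpRegistry::Global(), {});
//   ... import a GraphDef containing Switch/Merge nodes into `graph` ...
//   TF_RETURN_IF_ERROR(
//       FunctionalizeCond(&graph, &library, /*node_filter=*/{}));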
Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
                         const NodeFilter& node_filter) {
  // FunctionalizeControlFlow is invoked for every function, so the loops'
  // bodies and conditionals that were extracted into functions will be
  // handled in successive invocations.
  return functionalize_cond::FunctionalizeCond::Functionalize(graph, library,
                                                              node_filter);
}

}  // namespace tensorflow