• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
17 
#include <algorithm>
#include <array>
#include <deque>
#include <set>
#include <stack>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/str_join.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
29 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
30 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/common_runtime/shape_refiner.h"
34 #include "tensorflow/core/framework/graph_to_functiondef.h"
35 #include "tensorflow/core/framework/node_def_builder.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/algorithm.h"
38 #include "tensorflow/core/graph/control_flow.h"
39 #include "tensorflow/core/graph/node_builder.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/hash/hash.h"
42 #include "tensorflow/core/lib/strings/strcat.h"
43 #include "tensorflow/core/util/dump_graph.h"
44 
45 using xla::StatusOr;
46 
47 namespace tensorflow {
48 namespace functionalize_cond {
49 
operator <(const AncestorNode & other) const50 bool AncestorNode::operator<(const AncestorNode& other) const {
51   return (output_tensor.node->id() < other.output_tensor.node->id()) ||
52          (output_tensor.node->id() == other.output_tensor.node->id() &&
53           output_tensor.index < other.output_tensor.index) ||
54          (output_tensor.node->id() == other.output_tensor.node->id() &&
55           output_tensor.index == other.output_tensor.index &&
56           type < other.type);
57 }
58 
operator ==(const AncestorNode & other) const59 bool AncestorNode::operator==(const AncestorNode& other) const {
60   return output_tensor.node->id() == other.output_tensor.node->id() &&
61          output_tensor.index == other.output_tensor.index && type == other.type;
62 }
63 
operator ()(const AncestorNode & ancestor) const64 size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
65   size_t h = std::hash<int>()(ancestor.output_tensor.node->id());
66   h = Hash64Combine(h, std::hash<int>()(ancestor.output_tensor.index));
67   return Hash64Combine(h, std::hash<int>()(static_cast<int>(ancestor.type)));
68 }
69 
70 typedef std::tuple<StateMap::CondId, StateMap::AncestorId, OutputTensor>
71     ClusterTuple;
72 
73 struct ClusterTupleLessThan {
operator ()tensorflow::functionalize_cond::ClusterTupleLessThan74   bool operator()(const ClusterTuple& a, const ClusterTuple& b) const {
75     if (std::tie(std::get<0>(a), std::get<1>(a)) <
76         std::tie(std::get<0>(b), std::get<1>(b))) {
77       return true;
78     } else if (std::tie(std::get<0>(a), std::get<1>(a)) ==
79                std::tie(std::get<0>(b), std::get<1>(b))) {
80       return StateMap::OutputTensorLess()(std::get<2>(a), std::get<2>(b));
81     } else {
82       return false;
83     }
84   }
85 };
86 
87 // TODO(jpienaar): Move to OutputTensor.
DebugString(const OutputTensor & tensor)88 string DebugString(const OutputTensor& tensor) {
89   return absl::StrCat(tensor.node->name(), ":", tensor.index);
90 }
91 
Branch_Name(BranchType b)92 string Branch_Name(BranchType b) {
93   switch (b) {
94     case BranchType::kElseBranch:
95       return "else";
96     case BranchType::kThenBranch:
97       return "then";
98     case BranchType::kBoth:
99       return "both";
100     case BranchType::kNeither:
101       return "neither";
102   }
103 }
104 
DebugString(StateMap::CondId cond_state)105 string DebugString(StateMap::CondId cond_state) {
106   if (cond_state == nullptr || cond_state->empty()) return "{}";
107   using value_type = StateMap::CondState::value_type;
108   return absl::StrCat(
109       "{",
110       absl::StrJoin(*cond_state, ", ",
111                     [](string* output, const value_type& pred_branch) {
112                       const OutputTensor& pred = pred_branch.first;
113                       const BranchType& branch = pred_branch.second;
114                       if (branch == BranchType::kNeither)
115                         absl::StrAppend(output, "d");
116                       else
117                         absl::StrAppend(output, "s(", DebugString(pred), ",",
118                                         Branch_Name(branch), ")");
119                     }),
120       "}");
121 }
122 
123 // Returns the predicate of a switch.
GetSwitchPredicate(const Node & switch_node,OutputTensor * pred)124 Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
125   const Edge* pred_edge;
126   TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge));
127   // The predicate can be preceded by a identity node. Look through
128   // identity nodes to predicate.
129   while (pred_edge->src()->IsIdentity()) {
130     TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge));
131   }
132   *pred = OutputTensor(pred_edge->src(), pred_edge->src_output());
133   return Status::OK();
134 }
135 
GetSwitchValue(const Node & switch_node,OutputTensor * val)136 Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
137   const Edge* val_edge;
138   TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
139   *val = OutputTensor(val_edge->src(), val_edge->src_output());
140   return Status::OK();
141 }
142 
operator ()(const OutputTensor & lhs,const OutputTensor & rhs) const143 bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
144                                             const OutputTensor& rhs) const {
145   return (lhs.node->id() < rhs.node->id()) ||
146          (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
147 }
148 
149 struct CondStateLess {
operator ()tensorflow::functionalize_cond::CondStateLess150   bool operator()(const StateMap::CondState::value_type& lhs,
151                   const StateMap::CondState::value_type& rhs) const {
152     if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
153       return true;
154     if (lhs.first.node->id() == rhs.first.node->id() &&
155         lhs.first.index == rhs.first.index)
156       return lhs.second < rhs.second;
157     return false;
158   }
159 };
160 
StateMap(Graph * graph)161 StateMap::StateMap(Graph* graph) {
162   node_to_condid_map_.resize(graph->num_node_ids());
163   node_to_ancestorid_map_.resize(graph->num_node_ids());
164   // Initialize the dead state (empty state is designated with a nullptr).
165   dead_id_ = GetCondId(
166       {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
167 }
168 
IsDead(StateMap::CondId id) const169 bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }
170 
IsEmpty(StateMap::CondId id) const171 bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }
172 
operator ()(const StateMap::CondState & map) const173 size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
174   if (map.empty()) return 0;
175   // Compute hash of the front element.
176   auto it = map.begin();
177   size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
178                            hash<BranchType>()(it->second));
179   for (++it; it != map.end(); ++it) {
180     // Combine the has with the different elements in the map.
181     h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
182                                        hash<BranchType>()(it->second)));
183   }
184   return h;
185 }
186 
operator ()(const StateMap::AncestorState & map) const187 size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
188   if (map.empty()) return 0;
189   // Compute hash of the front element.
190   auto it = map.begin();
191   size_t h = AncestorNode::Hash()(*it);
192   for (++it; it != map.end(); ++it) {
193     // Combine the has with the different elements in the map.
194     h = Hash64Combine(h, AncestorNode::Hash()(*it));
195   }
196   return h;
197 }
198 
199 // CondArgNode represents a input to the conditional and its corresponding
200 // switch nodes.
201 struct CondArgNode {
CondArgNodetensorflow::functionalize_cond::CondArgNode202   explicit CondArgNode(Node* src, int src_output)
203       : src(src), src_output(src_output) {}
204 
ToStringtensorflow::functionalize_cond::CondArgNode205   string ToString() const {
206     return absl::StrCat("src=", src->name(), ":", src_output,
207                         " switches=", NodesToString(switches));
208   }
209 
210   Node* src;
211   int src_output;
212   std::array<Node*, 2> branch_copy;
213   std::vector<Node*> switches;
214 };
215 using CondArgNodes = std::vector<CondArgNode>;
216 
DebugString(const CondArgNodes & nodes)217 string DebugString(const CondArgNodes& nodes) {
218   return absl::StrCat(
219       "[",
220       absl::StrJoin(nodes, ", ",
221                     [](string* output, const CondArgNode& node) {
222                       absl::StrAppend(output, node.ToString());
223                     }),
224       "]");
225 }
226 
LookupCondId(const Node * node) const227 StateMap::CondId StateMap::LookupCondId(const Node* node) const {
228   const int64 map_size = node_to_condid_map_.size();
229   if (node->id() < map_size) return node_to_condid_map_[node->id()];
230   return added_node_condid_mapping_.at(node->id());
231 }
232 
GetCondId(const StateMap::CondState & state)233 StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
234   if (state.empty()) return nullptr;
235   return &*condstate_set_.insert(state).first;
236 }
237 
ResetCondId(const Node * node,StateMap::CondId id)238 void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
239   const int64 map_size = node_to_condid_map_.size();
240   if (node->id() < map_size)
241     node_to_condid_map_[node->id()] = id;
242   else
243     added_node_condid_mapping_[node->id()] = id;
244 }
245 
LookupAncestorId(const Node * node) const246 StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
247   const int64 map_size = node_to_ancestorid_map_.size();
248   if (node->id() < map_size) return node_to_ancestorid_map_[node->id()];
249   return added_node_ancestorid_mapping_.at(node->id());
250 }
251 
GetAncestorId(const StateMap::AncestorState & state)252 StateMap::AncestorId StateMap::GetAncestorId(
253     const StateMap::AncestorState& state) {
254   if (state.empty()) return nullptr;
255   return &*ancestorstate_set_.insert(state).first;
256 }
257 
ResetAncestorId(const Node * node,StateMap::AncestorId id)258 void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
259   const int64 map_size = node_to_ancestorid_map_.size();
260   if (node->id() < map_size)
261     node_to_ancestorid_map_[node->id()] = id;
262   else
263     added_node_ancestorid_mapping_[node->id()] = id;
264 }
265 
MarkDead(const Node * node)266 void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
267 
CondStateToString(const Node * node) const268 string StateMap::CondStateToString(const Node* node) const {
269   return CondStateToString(LookupCondId(node));
270 }
271 
CondStateToString(StateMap::CondId id) const272 string StateMap::CondStateToString(StateMap::CondId id) const {
273   return DebugString(id);
274 }
275 
AncestorStateToString(const Node * node) const276 string StateMap::AncestorStateToString(const Node* node) const {
277   if (auto id = LookupAncestorId(node)) {
278     return absl::StrCat(
279         "{",
280         absl::StrJoin(*id, ",",
281                       [](string* output, const AncestorNode& ancestor) {
282                         absl::StrAppend(output,
283                                         ancestor.output_tensor.node->name(),
284                                         ":", ancestor.output_tensor.index);
285                       }),
286         "}");
287   }
288   return "{}";
289 }
290 
FunctionalizeCond(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)291 FunctionalizeCond::FunctionalizeCond(Graph* graph,
292                                      FunctionLibraryDefinition* library,
293                                      const NodeFilter& node_filter)
294     : state_map_(graph),
295       library_(library),
296       graph_(graph),
297       node_filter_(node_filter) {}
298 
299 // Class representing the merge/switch nodes that will become a conditional.
300 class Conditional {
301  public:
302   Conditional(OutputTensor predicate, FunctionalizeCond* parent,
303               StateMap* cond_state_map, const ShapeRefiner& refiner);
304 
305   // Adds merge node that is part of this conditional.
306   Status AddMerge(Node* m);
307 
308   // Constructs an If node from the merge nodes.
309   Status BuildAndReplace(
310       Graph* graph, FunctionLibraryDefinition* library,
311       std::unordered_map<Node*, OutputTensor>* merge_to_replacement);
312 
313  private:
314   // Extracts the then/else bodies: creates new graphs with the nodes
315   // corresponding to the nodes in the then/else branches as of this conditional
316   // as function bodies.
317   Status ExtractBodies(Graph* graph);
318 
319   // Builds the arguments that are the input to the If.
320   Status BuildArgumentNodes();
321 
322   // Builds the If node for the extracted bodies with the given predicate.
323   Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library);
324 
325   // Adds input edges to If node.
326   Status AddInputEdges(
327       Graph* graph,
328       const std::unordered_map<Node*, OutputTensor>& merge_to_replacement);
329 
330   // Adds output edges from If node.
331   // Record new output tensor for all Merge nodes in 'merge_to_replacement'.
332   Status AddOutputEdges(
333       Graph* graph,
334       std::unordered_map<Node*, OutputTensor>* merge_to_replacement);
335 
336   // Adds switch node that is part of this conditional.
337   Status AddSwitch(Node* s);
338 
339   // Adds a switch node along the edge and rewire the edge to go via the switch.
340   Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
341                                 Graph* graph);
342 
343   // Internal name of conditional. The name is based on the first merge node
344   // added.
345   string name() const;
346 
347   // The FunctionalizeCond instance that created this.
348   FunctionalizeCond* parent_;
349 
350   // Mapping between nodes and their cond state.
351   StateMap* state_map_;
352 
353   // The predicate of the conditional.
354   OutputTensor predicate_;
355 
356   // Shape refiner of ops in the graph.
357   const ShapeRefiner& refiner_;
358 
359   // The predicate of the switches of the conditional. This may be different
360   // than predicate (which is initialized from the original graph) as the
361   // predicate could be the output of a newly created If node.
362   OutputTensor switch_predicate_;
363 
364   // Switch nodes in graph that are part of this conditional.
365   std::set<Node*, NodeCmpByNameResourcesLast> switches_;
366 
367   // Merge nodes in graph that are part of this conditional.
368   std::set<Node*, NodeCmpByNameResourcesLast> merges_;
369 
370   // Vector of control inputs from outside the conditional to a node inside.
371   std::vector<Node*> external_control_inputs_;
372   std::vector<Node*> external_control_outputs_;
373 
374   // Graphs corresponding to the then and else branch.
375   std::array<std::unique_ptr<Graph>, 2> bodies_;
376 
377   // Maps from graph_ to the branch body's graph.
378   std::array<std::vector<Node*>, 2> node_maps_;
379 
380   // The argument nodes created for the switches.
381   CondArgNodes cond_arg_nodes_;
382 
383   // The constructed If node.
384   Node* if_node_ = nullptr;
385 
386   // Whether the merge nodes of this conditional have been replaced.
387   bool replaced_ = false;
388 };
389 
Conditional(OutputTensor predicate,FunctionalizeCond * parent,StateMap * cond_state_map,const ShapeRefiner & refiner)390 Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
391                          StateMap* cond_state_map, const ShapeRefiner& refiner)
392     : parent_(parent),
393       state_map_(cond_state_map),
394       predicate_(predicate),
395       refiner_(refiner) {}
396 
AddMerge(Node * m)397 Status Conditional::AddMerge(Node* m) {
398   merges_.insert(m);
399   return Status::OK();
400 }
401 
AddSwitch(Node * s)402 Status Conditional::AddSwitch(Node* s) {
403   VLOG(5) << "Adding switch " << s->DebugString();
404   OutputTensor predicate;
405   TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate));
406   if (switch_predicate_.node == nullptr) switch_predicate_ = predicate;
407   if (!(switch_predicate_ == predicate)) {
408     return errors::InvalidArgument(
409         "Merge nodes ", NodesToString(merges_),
410         " directly dominated by switch nodes with different predicates (",
411         DebugString(switch_predicate_), " vs ", DebugString(predicate), ").");
412   }
413   switches_.insert(s);
414   parent_->AddSwitchId(s->id());
415   return Status::OK();
416 }
417 
BuildArgumentNodes()418 Status Conditional::BuildArgumentNodes() {
419   VLOG(1) << "Build function arguments";
420   struct Hash {
421     size_t operator()(const std::pair<Node*, int>& item) const {
422       return Hash64Combine(hash<Node*>()(item.first),
423                            std::hash<int>()(item.second));
424     }
425   };
426 
427   std::unordered_map<std::pair<Node*, int>, int, Hash> input_index;
428   for (Node* switch_node : switches_) {
429     const Edge* e;
430     TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e));
431     std::pair<Node*, int> key = std::make_pair(e->src(), e->src_output());
432     if (input_index.find(key) == input_index.end()) {
433       input_index[key] = cond_arg_nodes_.size();
434       cond_arg_nodes_.emplace_back(key.first, key.second);
435     }
436     cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node);
437   }
438   VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_);
439 
440   int arg_count = 0;
441   for (CondArgNode& cond_arg_node : cond_arg_nodes_) {
442     DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output);
443     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
444       int branch_index = static_cast<int>(branch);
445       TF_RETURN_IF_ERROR(
446           NodeBuilder(absl::StrCat("_Arg", arg_count),
447                       FunctionLibraryDefinition::kArgOp)
448               .Attr("T", dtype)
449               .Attr("index", arg_count)
450               .Finalize(bodies_[branch_index].get(),
451                         &cond_arg_node.branch_copy[branch_index]));
452     }
453     for (Node* node : cond_arg_node.switches) {
454       for (const Edge* e : node->out_edges()) {
455         if (e->IsControlEdge()) continue;
456         int branch_index = e->src_output();
457         Node* src_copy = cond_arg_node.branch_copy[branch_index];
458         Node* dst_copy = node_maps_[branch_index][e->dst()->id()];
459 
460         // The graph may contain dead switch nodes,
461         if (dst_copy == nullptr) continue;
462 
463         TF_RET_CHECK(dst_copy != nullptr)
464             << "Unable to find copied node for " << e->dst()->DebugString()
465             << " on branch " << Branch_Name(BranchType(branch_index));
466         // If the input goes directly to a merge then the merge has
467         // been replaced by a retval so the dst input is 0 instead of
468         // dst_input.
469         int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input();
470         bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input);
471       }
472     }
473     ++arg_count;
474   }
475 
476   // Verify that all retvals have an input.
477   // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have
478   // input.
479   for (Node* m : merges_) {
480     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
481       bool has_input = false;
482       for (auto e : node_maps_[static_cast<int>(branch)][m->id()]->in_edges()) {
483         if (!e->IsControlEdge()) {
484           has_input = true;
485           break;
486         }
487       }
488       if (!has_input) {
489         return errors::Internal(
490             "Failed to functionalize control flow with merge ",
491             FormatNodeForError(*m), " that doesn't have input on ",
492             Branch_Name(branch), " branch.");
493       }
494     }
495   }
496 
497   return Status::OK();
498 }
499 
AddSwitchNodeAlongEdge(const Edge * edge,BranchType branch,Graph * graph)500 Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
501                                            Graph* graph) {
502   // Previously we had edge:
503   //   src:src_output ---- edge ----> dst:dst_input
504   // post this we have (in graph)
505   //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input
506 
507   // TODO(jpienaar): One could keep a map caching the extra switch nodes added
508   // to avoid adding another switch to feed a value for which a switch was
509   // already added.
510   Node* switch_node;
511   Node* src = edge->src();
512   int src_output = edge->src_output();
513   TF_RETURN_IF_ERROR(
514       NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
515                   "Switch")
516           .Input(src, src_output)
517           .Input(const_cast<Node*>(predicate_.node), predicate_.index)
518           .Finalize(graph, &switch_node));
519   state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
520   state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));
521 
522   Node* dst = edge->dst();
523   int dst_input = edge->dst_input();
524   graph->RemoveEdge(edge);
525   graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
526   return AddSwitch(switch_node);
527 }
528 
ExtractBodies(Graph * graph)529 Status Conditional::ExtractBodies(Graph* graph) {
530   VLOG(2) << "Extracting bodies for " << name();
531   for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
532     bodies_[static_cast<int>(b)] =
533         absl::make_unique<Graph>(graph->op_registry());
534   }
535 
536   auto find_branch = [&](const Edge* e) {
537     const auto& id = state_map_->LookupCondId(e->src());
538     return IsSwitch(e->src()) ? BranchType(e->src_output())
539                               : state_map_->FindBranchOf(id, predicate_);
540   };
541 
542   std::array<std::vector<Node*>, 2> stacks;
543   VLOG(5) << "Merges: " << NodesToString(merges_);
544   for (Node* m : merges_) {
545     VLOG(5) << "For merge: " << m->DebugString() << " "
546             << state_map_->CondStateToString(m);
547     for (auto e : m->in_edges()) {
548       if (e->IsControlEdge()) continue;
549       BranchType branch = find_branch(e);
550       TF_RET_CHECK(branch == BranchType::kThenBranch ||
551                    branch == BranchType::kElseBranch)
552           << "Error: " << e->src()->name()
553           << " is not on either then or else branch (" << Branch_Name(branch)
554           << ") for predicate " << DebugString(predicate_) << " ["
555           << DebugString(state_map_->LookupCondId(e->src())) << "].";
556       Node* src = e->src();
557       if (IsSwitch(src)) {
558         // Switch node outputs and dependencies are handled separately.
559         TF_RETURN_IF_ERROR(AddSwitch(src));
560       } else {
561         stacks[static_cast<int>(branch)].push_back(src);
562       }
563     }
564   }
565 
566   for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
567     int branch_index = static_cast<int>(branch);
568     auto output = bodies_[branch_index].get();
569     auto& stack = stacks[branch_index];
570     VLOG(5) << "In branch: " << Branch_Name(branch) << " "
571             << NodesToString(stack);
572     std::vector<bool> visited(graph->num_node_ids(), false);
573     node_maps_[branch_index].resize(graph->num_node_ids(), nullptr);
574     auto& node_map = node_maps_[branch_index];
575 
576     while (!stack.empty()) {
577       Node* n = stack.back();
578       stack.pop_back();
579 
580       if (visited.at(n->id())) continue;
581       visited[n->id()] = true;
582 
583       // Verify output edges and record control edges exiting scope.
584       for (const Edge* e : n->out_edges()) {
585         Node* dst = e->dst();
586         if (IsMerge(dst)) continue;
587         Node* src = e->src();
588 
589         auto dst_id = state_map_->LookupCondId(dst);
590         auto src_id = state_map_->LookupCondId(src);
591         if (dst_id != src_id) {
592           if (e->IsControlEdge()) {
593             external_control_outputs_.push_back(e->src());
594           } else {
595             // Constants are treated specially to workaround the case of
596             // non-dominated constant nodes.
597             if (!IsConstant(src)) {
598               // TODO(b/78882471): A node that feeds into two different
599               // CondState is not necessarily an error so log a warning for now
600               // but revisit to improve the testing to enable making this an
601               // error.
602               LOG(WARNING) << errors::InvalidArgument(
603                   "Graph contains node ", FormatNodeForError(*src),
604                   " that feeds into node ", FormatNodeForError(*dst),
605                   " but these nodes are in different control contexts (",
606                   DebugString(src_id), " vs ", DebugString(dst_id),
607                   " (detected during out edge testing)");
608             }
609           }
610         }
611       }
612 
613       // Copying incoming edges to dst node. Iterate over a copy of the edges
614       // as they could be mutated during iteration.
615       std::vector<const Edge*> in_edges(n->in_edges().begin(),
616                                         n->in_edges().end());
617       // Sort in_edges to make sure nodes are copied in a deterministic order.
618       std::sort(
619           in_edges.begin(), in_edges.end(), [](const Edge* a, const Edge* b) {
620             int a_src_output = a->src_output(), b_src_output = b->src_output();
621             StringPiece a_name(a->src()->name()), b_name(b->src()->name());
622             return std::tie(a_src_output, a_name) <
623                    std::tie(b_src_output, b_name);
624           });
625       for (const Edge* e : in_edges) {
626         Node* src = e->src();
627         // Skip src/dst node.
628         if (!src->IsOp()) continue;
629 
630         Node* dst = e->dst();
631         if (IsSwitch(src)) {
632           // Switch node outputs and dependencies are handled separately.
633           TF_RETURN_IF_ERROR(AddSwitch(src));
634           continue;
635         }
636 
637         // Verify input is from the same context.
638         auto src_id = state_map_->LookupCondId(src);
639         auto dst_id = state_map_->LookupCondId(dst);
640         if (IsMerge(dst) || src_id == dst_id) {
641           // TODO(jpienaar): The merge case can be more strict.
642           if (node_map.at(src->id()) == nullptr) {
643             node_map.at(src->id()) = output->CopyNode(src);
644             stack.push_back(src);
645           }
646         } else if (e->IsControlEdge()) {
647           // Here we have a control flow edge between src and dst that are not
648           // in the same context. This is an external control dependency except
649           // for one case: where the only difference between CondId of e->src()
650           // and CondId of e->dst() is that e->src() has {PRED, kNeither} and
651           // e->dst() has {PRED, kThenBranch/kElseBranch}. This happens in
652           // gradients code for tf.cond(), where e->src() is a control pivot
653           // node for a branch and e->dst() is a data node in that branch.
654           bool is_external_control_input = true;
655           if (!state_map_->IsEmpty(src_id) && !state_map_->IsEmpty(dst_id)) {
656             std::vector<StateMap::CondState::value_type> diff;
657             std::set_symmetric_difference(
658                 src_id->begin(), src_id->end(), dst_id->begin(), dst_id->end(),
659                 std::back_inserter(diff), CondStateLess());
660             if (diff.size() == 2 && diff[0].first == diff[1].first &&
661                 (diff[0].second == BranchType::kNeither ||
662                  diff[1].second == BranchType::kNeither)) {
663               auto src_branch = src_id->find(diff[0].first);
664               if (src_branch != src_id->end() &&
665                   src_branch->second == BranchType::kNeither) {
666                 is_external_control_input = false;
667               }
668             }
669           }
670           if (is_external_control_input) {
671             external_control_inputs_.push_back(src);
672           }
673         } else {
674           // This shouldn't happen, this means we have an external data input
675           // not entering via a switch node. Work around this by for
676           // * constant nodes copy them;
677           // * non-constant nodes, insert a switch along the edge;
678           if (IsConstant(src)) {
679             // Check if constant node was added already. It is possible to have
680             // multiple uses of a constant node.
681             if (node_map.at(src->id()) == nullptr) {
682               node_map.at(src->id()) = output->CopyNode(src);
683             }
684           } else {
685             StateMap::CondState state = *dst_id;
686             state.erase(predicate_);
687             if (state_map_->GetCondId(state) == src_id) {
688               TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
689               continue;
690             } else {
691               return errors::InvalidArgument(
692                   "Graph contains node ", FormatNodeForError(*src),
693                   " that feeds into node ", FormatNodeForError(*dst),
694                   " but these nodes are in different control contexts (",
695                   DebugString(src_id), " vs ", DebugString(dst_id),
696                   " (detected during in edge testing)");
697             }
698           }
699         }
700 
701         Node* src_copy = node_map.at(e->src()->id());
702         int src_output = e->src_output();
703         if (node_map.at(dst->id()) == nullptr) {
704           node_map.at(dst->id()) = output->CopyNode(dst);
705         }
706         Node* dst_copy = node_map.at(e->dst()->id());
707         if (e->IsControlEdge()) {
708           // Skip control inputs from external context.
709           if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy);
710         } else {
711           output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
712         }
713       }
714     }
715   }
716 
717   // Build return values from the merge nodes.
718   int index = 0;
719   for (Node* m : merges_) {
720     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
721       int branch_index = static_cast<int>(branch);
722       auto& node_map = node_maps_[branch_index];
723       auto output = bodies_[branch_index].get();
724       TF_ASSIGN_OR_RETURN(node_map[m->id()],
725                           BuildRetvalNode(output, m->output_type(0), index));
726     }
727     ++index;
728 
729     // Connect the input to the merge_ with the retval, except if it is a
730     // Switch node, which is handled separately.
731     for (auto e : m->in_edges()) {
732       if (e->IsControlEdge()) continue;
733       int branch_index = static_cast<int>(find_branch(e));
734       auto& node_map = node_maps_[branch_index];
735       auto output = bodies_[branch_index].get();
736       Node* in = e->src();
737       if (!IsSwitch(in)) {
738         if (node_map.at(in->id()) == nullptr) {
739           node_map[in->id()] = output->CopyNode(in);
740         }
741         output->AddEdge(node_map[in->id()], e->src_output(),
742                         node_map.at(m->id()), 0);
743       }
744     }
745   }
746   return Status::OK();
747 }
748 
BuildIfNode(Graph * graph,FunctionLibraryDefinition * library)749 Status Conditional::BuildIfNode(Graph* graph,
750                                 FunctionLibraryDefinition* library) {
751   VLOG(2) << "Build cond function for " << name();
752   NodeDebugInfo debug_info((*merges_.begin())->def());
753   NodeDefBuilder builder(name(), "If", library, &debug_info);
754   const string branch_name[] = {"else_branch", "then_branch"};
755   for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
756     int branch_index = static_cast<int>(branch);
757 
758     NameAttrList body_name;
759     body_name.set_name(library->UniqueFunctionName(
760         absl::StrCat("_functionalize_if_", branch_name[branch_index], "_")));
761 
762     VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
763             << "): "
764             << DumpGraphToFile(
765                    "functionalize_cond_body_" + branch_name[branch_index],
766                    *bodies_[branch_index], nullptr);
767 
768     FunctionDef body_fdef;
769     TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index],
770                                           body_name.name(), &body_fdef));
771     TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
772     builder.Attr(branch_name[branch_index], body_name);
773   }
774 
775   VLOG(3) << "Build input type";
776   std::vector<NodeDefBuilder::NodeOut> inputs;
777   DataTypeVector in_arg_types;
778   for (auto& kv : cond_arg_nodes_) {
779     bool inserted = false;
780     for (const Node* arg : kv.switches) {
781       const Edge* in_edge;
782       TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
783       if (in_edge->IsControlEdge()) {
784         builder.ControlInput(in_edge->src()->name());
785       } else {
786         if (!inserted) {
787           DataType dtype = arg->input_type(0);
788           inputs.emplace_back(NodeDefBuilder::NodeOut(
789               in_edge->src()->name(), in_edge->src_output(), dtype));
790           in_arg_types.push_back(dtype);
791           inserted = true;
792         }
793       }
794     }
795   }
796   builder.Attr("Tin", in_arg_types);
797 
798   DataTypeVector out_type;
799   std::vector<PartialTensorShape> output_shapes;
800   output_shapes.reserve(merges_.size());
801   for (const Node* merge : merges_) {
802     DataType dtype = merge->output_type(0);
803     TensorShapeProto shape;
804     if (auto* shape_ctx = refiner_.GetContext(merge)) {
805       shape_inference::ShapeHandle handle;
806       shape_ctx->ShapeHandleToProto(shape_ctx->output(0), &shape);
807     }
808     out_type.push_back(dtype);
809     output_shapes.push_back(shape);
810   }
811   builder.Attr("Tout", out_type);
812   VLOG(3) << "Build output type: " << DataTypeVectorString(out_type);
813   builder.Attr("output_shapes", output_shapes);
814   VLOG(3) << "Build output shapes: "
815           << PartialTensorShapeUtils::PartialShapeListString(output_shapes);
816 
817   builder.Attr("Tcond", DT_BOOL);
818   // Add some internal attributes which need to be propagated.
819   // TODO(b/160275126): attributes shouldn't be hard-coded here
820   for (const char* attr_name :
821        {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
822         kTpuReplicateAttrName}) {
823     string attr_val;
824     if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) {
825       builder.Attr(attr_name, attr_val);
826     }
827   }
828   builder.Device(predicate_.node->assigned_device_name());
829   // Conditional should be the first input ...
830   builder.Input(
831       NodeDefBuilder::NodeOut(predicate_.node->name(), predicate_.index,
832                               predicate_.node->output_type(predicate_.index)));
833   // ... followed by the other inputs.
834   builder.Input(inputs);
835 
836   VLOG(3) << "Build If node";
837   NodeDef if_def;
838   TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
839   TF_ASSIGN_OR_RETURN(if_node_,
840                       parent_->AddIfNode(if_def, *merges_.begin(), predicate_));
841 
842   return Status::OK();
843 }
844 
AddInputEdges(Graph * graph,const std::unordered_map<Node *,OutputTensor> & merge_to_replacement)845 Status Conditional::AddInputEdges(
846     Graph* graph,
847     const std::unordered_map<Node*, OutputTensor>& merge_to_replacement) {
848   VLOG(2) << "AddInputEdges for " << if_node_->name();
849   int index = 0;
850   // Add predicate input.
851   if (predicate_.node->IsMerge()) {
852     // If the predicate is a Merge node, we should not use Merge output as
853     // predicate. Instead, we should use the corresponding If output in
854     // 'merge_to_replacement'. Otherwise, this Conditional's If node is still
855     // connected to the predicate Merge node; and when we call
856     // DeleteReachableAndDeadNodes(), the predicate Merge node and this
857     // Conditional's If node will be removed.
858     auto iter = merge_to_replacement.find(predicate_.node);
859     if (iter == merge_to_replacement.end()) {
860       return errors::Internal("Cannot find replacement for Merge node ",
861                               predicate_.node->name());
862     }
863     graph->AddEdge(iter->second.node, iter->second.index, if_node_, index++);
864   } else {
865     graph->AddEdge(const_cast<Node*>(predicate_.node), predicate_.index,
866                    if_node_, index++);
867   }
868   // Add function body inputs.
869   for (auto& arg : cond_arg_nodes_) {
870     if (arg.src_output == Graph::kControlSlot) {
871       graph->AddControlEdge(arg.src, if_node_);
872     } else {
873       graph->AddEdge(arg.src, arg.src_output, if_node_, index++);
874     }
875   }
876   for (Node* n : external_control_inputs_) {
877     graph->AddControlEdge(n, if_node_);
878   }
879   return Status::OK();
880 }
881 
AddOutputEdges(Graph * graph,std::unordered_map<Node *,OutputTensor> * merge_to_replacement)882 Status Conditional::AddOutputEdges(
883     Graph* graph,
884     std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
885   VLOG(2) << "AddOutputEdges for " << if_node_->name();
886   int i = 0;
887   for (Node* node : merges_) {
888     TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i));
889     std::vector<const Edge*> edges(node->out_edges().begin(),
890                                    node->out_edges().end());
891     for (const Edge* edge : edges) {
892       Node* dst = edge->dst();
893       int dst_input = edge->dst_input();
894       if (edge->src_output() > 0) {
895         return errors::Unimplemented("Output of index (", edge->src_output(),
896                                      ") of merge node ",
897                                      FormatNodeForError(*node));
898       }
899 
900       bool control_edge = edge->IsControlEdge();
901       graph->RemoveEdge(edge);
902       if (control_edge) {
903         graph->AddControlEdge(if_node_, dst);
904       } else {
905         graph->AddEdge(if_node_, i, dst, dst_input);
906       }
907     }
908 
909     // Record corresponding output tensor in 'merge_to_replacement'.
910     (*merge_to_replacement)[node] = OutputTensor{if_node_, i};
911 
912     ++i;
913   }
914   for (Node* n : external_control_outputs_) {
915     graph->AddControlEdge(if_node_, n);
916   }
917 
918   return Status::OK();
919 }
920 
// Performs the end-to-end replacement of this conditional's merge nodes with
// a functional If node: extracts the branch bodies, builds their argument
// nodes, constructs the If node, rewires input/output edges, and
// re-propagates cond state downstream. Idempotent: returns immediately if
// the replacement already happened.
Status Conditional::BuildAndReplace(
    Graph* graph, FunctionLibraryDefinition* library,
    std::unordered_map<Node*, OutputTensor>* merge_to_replacement) {
  VLOG(1) << "Build If and replace merge nodes "
          << NodesToString(this->merges_);
  if (replaced_) return Status::OK();

  TF_RETURN_IF_ERROR(ExtractBodies(graph));
  TF_RETURN_IF_ERROR(BuildArgumentNodes());

  // Dump both extracted branch bodies at high verbosity for debugging.
  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Extracted bodies:";
    for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
      int branch_index = static_cast<int>(branch);
      auto output = bodies_[branch_index].get();
      LOG(INFO) << Branch_Name(branch) << ": "
                << DebugString(output->ToGraphDefDebug());
    }
  }

  TF_RETURN_IF_ERROR(BuildIfNode(graph, library));
  TF_RETURN_IF_ERROR(AddInputEdges(graph, *merge_to_replacement));
  TF_RETURN_IF_ERROR(AddOutputEdges(graph, merge_to_replacement));
  TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));

  // Check that the if_node doesn't feed into itself.
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      CheckNodeNotInCycle(if_node_, graph->num_node_ids()),
      "Converting to If failed.");

  replaced_ = true;
  return Status::OK();
}
954 
name() const955 string Conditional::name() const {
956   CHECK(!merges_.empty());
957   return absl::StrCat((*merges_.begin())->name(), "_if");
958 }
959 
AddIdentityNode(const Node * replacee,Node * if_node,int port)960 Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
961                                           int port) {
962   NodeBuilder id_builder(replacee->name(), "Identity");
963   id_builder.Input(if_node, port);
964   string outside_compilation;
965   if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttrName,
966                   &outside_compilation)
967           .ok()) {
968     id_builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation);
969   }
970   Node* id;
971   TF_RETURN_IF_ERROR(id_builder.Finalize(graph_, &id));
972   state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
973   state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
974   return Status::OK();
975 }
976 
AddIfNode(const NodeDef & def,const Node * replacee,const OutputTensor & predicate)977 StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
978                                              const Node* replacee,
979                                              const OutputTensor& predicate) {
980   Status status;
981   Node* ret = graph_->AddNode(def, &status);
982   TF_RETURN_IF_ERROR(status);
983   VLOG(1) << "Adding If for " << replacee->name();
984   StateMap::CondId id = state_map_.LookupCondId(replacee);
985   if (id) {
986     StateMap::CondState state = *id;
987     state.erase(predicate);
988     state_map_.ResetCondId(ret, state_map_.GetCondId(state));
989   } else {
990     state_map_.ResetCondId(ret, nullptr);
991   }
992 
993   state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));
994 
995   return ret;
996 }
997 
// Recomputes the CondState of all nodes downstream of `replacee` after it
// has been rewired (e.g. replaced by an If node). Runs a worklist-driven
// forward pass in topological order; terminates early once the set of
// changed nodes is exhausted.
Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
  VLOG(2) << "Propagating update state for " << replacee->name() << " "
          << state_map_.CondStateToString(replacee);
  // Redo topological sort as the order could have changed.
  // TODO(jpienaar): The original topological order could also be updated
  // dynamically if needed.
  std::vector<Node*> rev_topo_order;
  GetPostOrder(*graph_, &rev_topo_order);

  // All the outputs of the new node could potentially be updated.
  std::unordered_set<Node*> changed;
  for (auto n : replacee->out_nodes())
    if (n->IsOp()) changed.insert(n);

  // Iterate through the changed/possible changed nodes in topological order.
  for (auto it = rev_topo_order.rbegin();
       it != rev_topo_order.rend() && !changed.empty(); ++it) {
    if (changed.find(*it) != changed.end()) {
      // Update the node state.
      Node* n = *it;
      StateMap::CondId old_state = state_map_.LookupCondId(n);
      // Clear the cached state before recomputing so DetermineCondState
      // starts from scratch for this node.
      state_map_.ResetCondId(n, nullptr);
      TF_RETURN_IF_ERROR(DetermineCondState(n));
      // Only fan out to successors if the state actually changed.
      if (state_map_.LookupCondId(n) != old_state) {
        for (auto out : n->out_nodes())
          if (out->IsOp()) changed.insert(out);
      }
      changed.erase(n);
    }
  }
  return Status::OK();
}
1030 
1031 // Returns the most restrictive branch of two branches or neither. This is the
1032 // meet operator of the BranchType lattice.
MeetBranch(const BranchType & lhs,const BranchType & rhs)1033 BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) {
1034   if (lhs == rhs) return lhs;
1035   if (lhs == BranchType::kNeither) return rhs;
1036   if (rhs == BranchType::kNeither) return lhs;
1037   if (lhs == BranchType::kBoth) return rhs;
1038   if (rhs == BranchType::kBoth) return lhs;
1039   return BranchType::kNeither;
1040 }
1041 
FindBranchOf(CondId id,OutputTensor predicate) const1042 BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
1043   if (IsEmpty(id)) return BranchType::kNeither;
1044   const CondState& nodes = *id;
1045   auto it = nodes.find(predicate);
1046   if (it == nodes.end()) return BranchType::kNeither;
1047   return it->second;
1048 }
1049 
// Joins the CondStates flowing into a non-merge node. Empty states act as
// identities and dead states are absorbing; otherwise the union of the two
// predicate->branch maps is taken, with kNeither yielding to any concrete
// branch. Conflicting concrete branches on the same predicate are an error.
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
    StateMap::CondId src, StateMap::CondId dst) {
  VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
          << "] and dst=" << DebugString(dst) << " [" << dst << "]";

  // Empty is the identity element; dead absorbs everything.
  if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;

  // Nothing to do if the CondState is the same.
  if (src == dst) return src;

  // Merge 'dst' entries into a copy of 'src', reconciling per predicate.
  StateMap::CondState both = *src;
  for (const auto& kv : *dst) {
    auto it = both.find(kv.first);
    if (it == both.end()) {
      both.insert(kv);
    } else {
      if (it->second != kv.second) {
        if (it->second == BranchType::kNeither) {
          // BranchType for 'src' is kNeither. Use the BranchType in 'dst'.
          it->second = kv.second;
        } else if (kv.second == BranchType::kNeither) {
          // BranchType for 'dst' is kNeither. Use the BranchType in 'src'.
          // No need to change it->second.
        } else {
          // Two distinct concrete branches for the same predicate cannot
          // both hold on one path.
          return errors::InvalidArgument(
              "Graph contains node with inputs predicated on incompatible "
              "predicates: ",
              DebugString(src), " and ", DebugString(dst));
        }
      }
    }
  }
  return state_map_.GetCondId(both);
}
1085 
StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
  // Determine the flow state when joining two states for a merge
  // node. Combining the two states for a merge node is effectively performing a
  // disjunction of the states along the different input edges. For a merge that
  // can be transformed into an If the two inputs paths have to have a predicate
  // on which they differ (e.g., along one edge predicate `p` has to hold while
  // on another it should not). This function first determines this predicate
  // and then the resultant state is the common path between the two inputs
  // followed by s(p, both).
  VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
          << DebugString(dst);
  if (state_map_.IsEmpty(dst)) return src;
  // A merge input with no CondState at all cannot be functionalized.
  if (state_map_.IsEmpty(src)) {
    return errors::Internal("Merge node ", merge->name(),
                            " has input that's not in any CondContext.");
  }

  // Dead state is absorbing.
  if (state_map_.IsDead(src)) return src;
  if (state_map_.IsDead(dst)) return dst;

  // 'diff' collects the predicate entries on which the two states disagree;
  // 'merged' is their common prefix (the shared outer context).
  std::vector<StateMap::CondState::value_type> diff;
  StateMap::CondState merged;
  std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
                                dst->end(), std::back_inserter(diff),
                                CondStateLess());
  std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
                        std::inserter(merged, merged.begin()), CondStateLess());

  // Update mapping from merge node to predicate.
  // The two inputs must differ on exactly one predicate (two diff entries:
  // one per side), taking opposite concrete branches of it.
  if (diff.size() == 2) {
    auto pred = diff[0].first;
    bool different_branches = (diff[0].second != diff[1].second) &&
                              (diff[0].second == BranchType::kThenBranch ||
                               diff[0].second == BranchType::kElseBranch) &&
                              (diff[1].second == BranchType::kThenBranch ||
                               diff[1].second == BranchType::kElseBranch);
    if (!(pred == diff[1].first) || !different_branches)
      return errors::InvalidArgument(
          "Unable to determine predicate for merge node");
    merge_to_predicate_[merge] = pred;
  } else {
    return errors::InvalidArgument(
        "Merge of two inputs that differ on more than one predicate ",
        DebugString(src), " and ", DebugString(dst));
  }

  return state_map_.GetCondId(merged);
}
1135 
StateAlongEdge(const Edge * e)1136 StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
1137   Node* src = e->src();
1138   StateMap::CondId id = state_map_.LookupCondId(e->src());
1139 
1140   // Dead nodes only propagate dead state.
1141   if (state_map_.IsDead(id)) return id;
1142 
1143   if (IsSwitch(src)) {
1144     StateMap::CondState state;
1145     if (id != nullptr) state = *id;
1146     OutputTensor predicate;
1147     TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
1148     if (e->IsControlEdge()) {
1149       // In gradients of tf.cond(), in each branch, we have a NoOp node as
1150       // control pivot. These NoOp nodes have control dependency from Switch
1151       // node. If we don't record this into CondState, branches might have
1152       // incorrect CondState (e.g. if the branch only has a Const data node).
1153       // We set it to kNeither because there is no way to tell whether it's
1154       // for true branch or false branch. This node's descendents might have
1155       // other incoming edges with defined BranchType, and we correctly handle
1156       // merging kNeither with other defined BranchType in StateAlongEdge().
1157       state[predicate] = BranchType::kNeither;
1158     } else {
1159       state[predicate] = BranchType(e->src_output());
1160     }
1161     return state_map_.GetCondId(state);
1162   }
1163   return id;
1164 }
1165 
DetermineCondStateMerge(Node * dst)1166 Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) {
1167   // Only Merge nodes with two inputs are supported, but if this is a redundant
1168   // merge, then the dead edge may already have been removed (if due to a
1169   // switch) and so the input count would be incorrect.
1170   if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();
1171 
1172   int data_inputs = 0;
1173   for (auto e : dst->in_edges()) {
1174     Node* src = e->src();
1175     VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
1176             << state_map_.CondStateToString(src);
1177     if (!src->IsOp()) continue;
1178     if (!e->IsControlEdge()) ++data_inputs;
1179 
1180     StateMap::CondId prop = StateAlongEdge(e);
1181     auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
1182     TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
1183                                     FormatNodeForError(*dst));
1184     state_map_.ResetCondId(dst, id_or.ValueOrDie());
1185   }
1186 
1187   // Incomplete Merge nodes are not supported.
1188   if (data_inputs != 2) {
1189     return errors::Unimplemented(
1190         dst->name(), " only has ", data_inputs,
1191         " inputs, while only merge nodes with two inputs supported.");
1192   }
1193   return Status::OK();
1194 }
1195 
DetermineCondStateNonMerge(Node * dst)1196 Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
1197   // Handle non-merge join.
1198   for (auto e : dst->in_edges()) {
1199     VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
1200             << state_map_.CondStateToString(dst);
1201     Node* src = e->src();
1202     if (!src->IsOp()) continue;
1203 
1204     // Joining the state between the current and propagated state.
1205     StateMap::CondId prop = StateAlongEdge(e);
1206     auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
1207     TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
1208                                     FormatNodeForError(*dst));
1209     state_map_.ResetCondId(dst, id_or.ValueOrDie());
1210   }
1211   return Status::OK();
1212 }
1213 
// Bypasses a merge node whose CondState is dead: all of its consumers are
// rewired to the single non-dead input, and the merge itself is marked dead
// so it gets cleaned up later.
Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
  // Handle redundant merge nodes. A merge node is considered redundant if
  // one input edge is dead while the other has a value.
  if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();

  // Find the single surviving (non-dead) data input.
  const Edge* non_dead_edge = nullptr;
  for (auto e : node->in_edges()) {
    if (e->IsControlEdge()) continue;
    Node* src = e->src();

    // Handle merge with dead state.
    const auto& src_id = state_map_.LookupCondId(src);
    if (!state_map_.IsDead(src_id)) {
      non_dead_edge = e;
      break;
    }
  }

  if (non_dead_edge == nullptr) {
    return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
                                   " has no non-dead inputs.");
  }
  state_map_.MarkDead(node);
  VLOG(5) << "removing redundant merge: " << node->name();
  // Drain the out edges one at a time (RemoveEdge mutates out_edges()),
  // reconnecting each consumer directly to the non-dead input. Control
  // consumers stay control edges; data consumers read the input's output.
  while (!node->out_edges().empty()) {
    const Edge* oe = *node->out_edges().begin();
    Node* dst_node = oe->dst();
    int dst_port = oe->dst_input();
    graph_->RemoveEdge(oe);
    graph_->AddEdge(non_dead_edge->src(),
                    dst_port == Graph::kControlSlot
                        ? Graph::kControlSlot
                        : non_dead_edge->src_output(),
                    dst_node, dst_port);
  }
  return Status::OK();
}
1251 
Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
  // Handle redundant switch nodes. A switch node is considered redundant if
  // the predicate of the switch already holds on the current branch. E.g., if
  // p is the predicate of the switch but p is already known to hold on this
  // branch, then the switch can be removed and the dead state propagated
  // along one. The checking of predicate is based on the exact predicate
  // (rather than boolean equivalence) and aimed at redundant switches as
  // currently generated by gradient code.
  StateMap::CondId dst_id = state_map_.LookupCondId(node);
  if (state_map_.IsDead(dst_id)) return Status::OK();

  BranchType b;
  OutputTensor pred;
  TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));

  // Determine if we are already on a branch where the switch predicate is
  // true/false. Consider both the data and predicate to determine if the
  // node is redundant (skipping over identity node).
  b = state_map_.FindBranchOf(dst_id, pred);
  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
    // The predicate itself is not in the state; check whether the switch's
    // data input (traced through Identity nodes) is a tensor whose branch is
    // already fixed in the state.
    OutputTensor val;
    const Edge* e;
    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
    val = OutputTensor(e->src(), e->src_output());
    while (IsIdentity(val.node)) {
      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
      val = OutputTensor(e->src(), e->src_output());
    }
    b = state_map_.FindBranchOf(dst_id, val);
    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
      return Status::OK();
  }

  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
          << DebugString(dst_id);
  const Edge* value_edge;
  TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
  Node* val_node = value_edge->src();
  int val_port = value_edge->src_output();
  // Drain the out edges one at a time (RemoveEdge mutates out_edges()).
  while (!node->out_edges().empty()) {
    auto e = *node->out_edges().begin();
    Node* dst_node = e->dst();
    int dst_input = e->dst_input();
    int switch_branch = e->src_output();
    graph_->RemoveEdge(e);
    if (switch_branch == Graph::kControlSlot) {
      // Control consumers keep their dependency, but their CondState must be
      // re-joined with the switch's state since the switch goes away.
      if (IsMerge(dst_node)) {
        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
                                         state_map_.LookupCondId(dst_node));
        TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                        FormatNodeForError(*dst_node));
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      } else {
        auto id_or =
            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
        TF_RETURN_IF_ERROR(id_or.status());
        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
      }
    } else if (BranchType(switch_branch) != b) {
      // Consumers on the branch that is never taken become dead and keep no
      // replacement edge.
      state_map_.MarkDead(dst_node);
      continue;
    }
    // Reconnect the consumer directly to the switch's data input.
    graph_->AddEdge(
        val_node,
        switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port,
        dst_node, dst_input);
  }
  return Status::OK();
}
1321 
DetermineStates(std::vector<Node * > rev_topo_order)1322 Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
1323   // The state that is propagated along the given edge.
1324   for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
1325     Node* dst = *it;
1326     TF_RETURN_IF_ERROR(DetermineCondState(dst));
1327     TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
1328     if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
1329     if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
1330 
1331     VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
1332             << " @ " << state_map_.AncestorStateToString(dst);
1333     if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
1334   }
1335   return Status::OK();
1336 }
1337 
// Computes the AncestorState of `dst`: the union, over all in-edges, of the
// source node's ancestor state plus the source itself when it is a Merge or
// a Switch (recorded via its predicate when one can be determined).
Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
  StateMap::AncestorId id = nullptr;
  StateMap::AncestorState state;

  // Folds the ancestor contribution of `src` into `state` and returns the
  // interned id of the accumulated set so far.
  auto insert = [&](StateMap::AncestorId id, Node* src) {
    auto other_id = state_map_.LookupAncestorId(src);
    // Skip re-inserting when src's set is the one already accumulated.
    if (other_id != id && other_id != nullptr) {
      state.insert(other_id->begin(), other_id->end());
    }
    if (IsMerge(src)) {
      state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
    } else if (IsSwitch(src)) {
      OutputTensor pred;
      // For dead switch nodes, GetSwitchPredicate() will fail, and we use
      // the switch node directly as ancestor.
      if (GetSwitchPredicate(*src, &pred).ok()) {
        state.insert({pred, AncestorNode::AncestorNodeType::kPred});
      } else {
        state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
      }
    }
    return state_map_.GetAncestorId(state);
  };

  // Compute the union of all the switch/merge nodes that affects the input of
  // dst.
  for (auto e : dst->in_edges()) {
    Node* src = e->src();
    id = insert(id, src);
  }
  state_map_.ResetAncestorId(dst, id);
  return Status::OK();
}
1371 
void FunctionalizeCond::DeleteReachableAndDeadNodes(
    const std::vector<Node*>& merge_order) {
  // Delete all nodes that have been extracted or are reachable from
  // deleted/dead nodes. The input and outgoing edges should have already been
  // removed.
  std::deque<int> delete_nodes;
  std::vector<bool> deleted(graph_->num_node_ids(), false);
  // Don't try to delete source or sink nodes.
  deleted[graph_->kSourceId] = true;
  deleted[graph_->kSinkId] = true;

  // All remaining switch nodes that were not excluded from functionalization
  // according to `node_filter_` are not reachable from a merge node and
  // removed. This is to account for dead switch nodes.
  for (int s_id : switch_ids_) {
    Node* s = graph_->FindNodeId(s_id);
    if (s == nullptr) continue;
    for (const Edge* e : s->out_edges()) {
      // Control outputs of switch nodes (which are unconditionally executed if
      // the switch is) are not removed as they need not be part of a
      // conditional.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    // Only remove switch node if we have functionalized the corresponding
    // condition before (according to `node_filter_`).
    if (!node_filter_ || node_filter_(s)) {
      VLOG(2) << "Removing obsolete switch node " << s->name();
      deleted[s_id] = true;
      graph_->RemoveNode(s);
    }
  }

  // All merge nodes that were not excluded from functionalization according to
  // `node_filter_` should have been transformed at this point and we remove
  // them from the graph here.
  for (Node* m : merge_order) {
    for (const Edge* e : m->out_edges()) {
      // Similar to control outputs of switch nodes don't remove control
      // outputs of merge nodes.
      // TODO(jpienaar): Check cases where output edges still exist here vs
      // being removed in AddOutputEdges.
      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
    }
    // Only remove merge node if we have functionalized the corresponding
    // condition before (according to `node_filter_`).
    if (!node_filter_ || node_filter_(m)) {
      VLOG(2) << "Removing obsolete merge node " << m->name();
      deleted[m->id()] = true;
      graph_->RemoveNode(m);
    }
  }

  // Enqueue all the dead nodes.
  for (Node* n : graph_->nodes()) {
    if (state_map_.IsDead(state_map_.LookupCondId(n))) {
      delete_nodes.push_back(n->id());
    }
  }
  // Remove dead nodes and nodes that are reachable from dead nodes.
  // This is a BFS over node ids; `deleted` guards against revisiting, and
  // each node's consumers are enqueued before the node itself is removed.
  while (!delete_nodes.empty()) {
    int d_id = delete_nodes.front();
    delete_nodes.pop_front();
    if (deleted[d_id]) continue;
    Node* d = graph_->FindNodeId(d_id);
    // Switch and Merge nodes could have been deleted already.
    if (d == nullptr) continue;
    for (const Edge* e : d->out_edges()) {
      delete_nodes.push_back(e->dst()->id());
    }
    VLOG(2) << "Removing obsolete node " << d->name();
    deleted[d_id] = true;
    graph_->RemoveNode(d);
  }
}
1446 
SortMergeNodes(std::vector<Node * > * merge_order)1447 void FunctionalizeCond::SortMergeNodes(std::vector<Node*>* merge_order) {
1448   // Sort merge nodes by nesting depth.
1449   using sort_pair = std::pair<int, Node*>;
1450   std::vector<sort_pair> inner_to_outer_merge_order;
1451   inner_to_outer_merge_order.reserve(merge_order->size());
1452   for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
1453     Node* merge = *it;
1454     StateMap::CondId id = state_map_.LookupCondId(merge);
1455     int depth = id != nullptr ? id->size() : 0;
1456     inner_to_outer_merge_order.emplace_back(depth, merge);
1457   }
1458   std::stable_sort(
1459       inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(),
1460       [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; });
1461   merge_order->clear();
1462   for (sort_pair t : inner_to_outer_merge_order) {
1463     merge_order->push_back(t.second);
1464   }
1465 }
1466 
Status FunctionalizeCond::FunctionalizeInternal() {
  // The general approach for converting a tf.cond (as lowered via switch/merge
  // nodes) to a functional if is as follows:
  // 1. Determine the topological order and collect all the switch and merge
  // nodes in the graph;
  // 2. Compute the predicates and dominance structure for all the nodes in the
  // graph - this includes which predicate must be true for a op to execute
  // (predicate values are considered directly rather than attempting to
  // determine deeper equivalence). We shall refer to this structure as the
  // CondState;
  // 3. Sort the merge nodes by nesting depth;
  // 4. Extract merge nodes together that have the same CondState and
  // AncestorState from the innermost to the outermost into IfOps;
  // Note: In the above only nodes that feed into a merge node will be
  // considered for functionalization.
  // Note: Nodes for which `node_filter_` returns false are excluded.

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be no
  //   cycles at this point so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Record the reverse topological order for merge and switch nodes;
  std::vector<Node*> rev_topo_order;
  std::vector<Node*> merge_order;
  DFS(*graph_, nullptr, [&](Node* n) {
    // Only collect switch and merge nodes that are not filtered out, those form
    // the conditions that will be functionalized.
    if (!node_filter_ || node_filter_(n)) {
      if (IsSwitch(n)) {
        AddSwitchId(n->id());
      }
      if (IsMerge(n)) {
        merge_order.push_back(n);
      }
    }
    // Collect all other nodes here, independent of `node_filter_`, because they
    // might belong to a condition that should be functionalized.
    if (n->IsOp()) {
      rev_topo_order.push_back(n);
    }
  });

  // No merges to functionalize.
  if (merge_order.empty()) {
    // No merges mean no switch values consumed (as only considering values
    // fetchable as output of merge); still clean up any dead nodes.
    DeleteReachableAndDeadNodes(merge_order);
    return Status::OK();
  }

  // Annotate every node with its CondState/AncestorState (step 2 above).
  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");

  // Determine the shapes of the ops in the graph. Shape inference failures are
  // non-fatal: only a warning is emitted and functionalization proceeds.
  ShapeRefiner shape_refiner{graph_->versions().producer(),
                             graph_->op_registry()};
  std::vector<Node*> nodes;
  GetReversePostOrder(*graph_, &nodes);
  for (auto node : nodes) {
    if (!shape_refiner.AddNode(node).ok()) {
      LOG(WARNING) << "Couldn't deduce shape for " << node->name();
    }
  }

  // Sort the merge nodes from innermost outwards (step 3 above).
  SortMergeNodes(&merge_order);

  // Cluster merge nodes by (CondId, AncestorId, predicate) in order of
  // nesting. (CondId, AncestorId) is not enough, e.g.
  //   pred1 = array_ops.placeholder(dtypes.bool, name='pred1')
  //   pred2 = array_ops.placeholder(dtypes.bool, name='pred2')
  //   cond1 = control_flow_ops.cond(pred1, ...)
  //   cond2 = control_flow_ops.cond(pred2, ...)
  //   cond3 = control_flow_ops.cond(pred1, use cond1 and cond2)
  //   cond4 = control_flow_ops.cond(pred2, use cond1 and cond2)
  // cond3 and cond4 have the same (CondId, AncestorId), but they should not
  // be merged into one "If" node (because they have different predicates).
  std::deque<std::vector<Node*>> merge_clusters;
  std::map<ClusterTuple, int, ClusterTupleLessThan> merge_cluster_index;
  for (Node* merge : merge_order) {
    auto cond_id = state_map_.LookupCondId(merge);
    // Merges in a dead branch are deleted later; don't cluster them.
    if (state_map_.IsDead(cond_id)) continue;

    auto predicate = merge_to_predicate_.find(merge);
    if (predicate == merge_to_predicate_.end()) {
      return errors::Internal("Cannot find predicate for Merge node ",
                              merge->name());
    }

    // Append the merge to the cluster with the same key, creating a new
    // cluster if this key has not been seen before.
    ClusterTuple key = std::make_tuple(
        cond_id, state_map_.LookupAncestorId(merge), predicate->second);
    auto idx = merge_cluster_index.find(key);
    if (idx == merge_cluster_index.end()) {
      merge_cluster_index[key] = merge_clusters.size();
      merge_clusters.push_back({merge});
    } else {
      merge_clusters[idx->second].emplace_back(merge);
    }
  }

  // Extract the conditionals from inner most to outer most. Extracting from
  // innermost to outermost enables the extraction pass to stop once it
  // encounters a Switch node instead of having to keep track of Switch/Merge
  // nodes seen.
  for (const auto& cluster : merge_clusters) {
    // Construct a Conditional with the predicate of the merge.
    Conditional cond(merge_to_predicate_.at(cluster.front()), this, &state_map_,
                     shape_refiner);
    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
    TF_RETURN_IF_ERROR(
        cond.BuildAndReplace(graph_, library_, &merge_to_replacement_));

    if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
  }

  // Remove the now-functionalized switch/merge nodes and everything reachable
  // from dead branches.
  DeleteReachableAndDeadNodes(merge_order);

  return Status::OK();
}
1586 
DumpGraphWithCondState(const string & name)1587 void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
1588   const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup";
1589 
1590   for (Node* n : graph_->nodes()) {
1591     n->ClearAttr(kCondGroupDebugAttr);
1592     n->AddAttr(kCondGroupDebugAttr,
1593                absl::StrCat(state_map_.CondStateToString(n), "_",
1594                             state_map_.AncestorStateToString(n)));
1595   }
1596   LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
1597             << DumpGraphToFile(absl::StrCat("functionalize_cond_", name),
1598                                *graph_, library_);
1599 }
1600 
AddSwitchId(int switch_id)1601 void FunctionalizeCond::AddSwitchId(int switch_id) {
1602   switch_ids_.push_back(switch_id);
1603 }
1604 
Functionalize(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)1605 Status FunctionalizeCond::Functionalize(Graph* graph,
1606                                         FunctionLibraryDefinition* library,
1607                                         const NodeFilter& node_filter) {
1608   VLOG(1) << "FunctionalizeCond::Functionalize";
1609   FunctionalizeCond fc(graph, library, node_filter);
1610   return fc.FunctionalizeInternal();
1611 }
1612 
1613 }  // namespace functionalize_cond
1614 
FunctionalizeCond(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)1615 Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
1616                          const NodeFilter& node_filter) {
1617   // FunctionalizeControlFlow is invoked for every function, so the loops's
1618   // bodies and conditionals that were extracted into functions will be handled
1619   // in successive invocations.
1620   return functionalize_cond::FunctionalizeCond::Functionalize(graph, library,
1621                                                               node_filter);
1622 }
1623 
1624 }  // namespace tensorflow
1625