1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ 18 19 #include <deque> 20 21 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" 22 #include "tensorflow/compiler/xla/status_macros.h" 23 #include "tensorflow/core/framework/function.h" 24 #include "tensorflow/core/graph/graph.h" 25 26 namespace tensorflow { 27 28 // Functionalize all the switch-merge nodes of a loop-free graph into If 29 // nodes. That is, attempt to transform every remaining switch and merge nodes 30 // in the graph into If nodes. 31 // 32 // If `node_filter` is defined, then only conditions for whose nodes 33 // `node_filter` returns true are functionalized. 34 // 35 // Preconditions: 36 // a) Same as for `FunctionalizeControlFlow` (see comment there). 37 // b) While loops must have been functionalized before according to 38 // `node_filter` (e.g., by calling `FunctionalizeWhileLoop` with the same 39 // filter before calling this function). 40 Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, 41 const NodeFilter& node_filter = {}); 42 43 // Internal functions/classes exposed for testing purposes. 44 namespace functionalize_cond { 45 46 // All nodes are assumed to be either in no branch, then branch, else branch, 47 // or both branches (such as merge nodes). 48 // The code below relies on Else and Then being 0 and 1 (corresponding to the 49 // switch outputs). Both and Neither are arbitrary. 50 enum class BranchType { 51 kElseBranch = 0, 52 kThenBranch = 1, 53 kBoth = 2, 54 kNeither = 3, 55 }; 56 57 // When we keep track of which switch/merge node's feed into a node, we record 58 // 1) predicate for non-dead switch node, 59 // 2) the switch node itself for dead switch node, 60 // 3) the merge node itself for merge node. 61 // Case 1) is an optimization. With this optimization, if there are nodes from 62 // different switch nodes but those switch nodes have the same predicate, the 63 // nodes will still have same AncestorState, and they will be clustered into a 64 // single "If". 65 struct AncestorNode { 66 enum class AncestorNodeType { 67 kPred = 0, 68 kSwitch = 1, 69 kMerge = 2, 70 }; 71 72 OutputTensor output_tensor; 73 AncestorNodeType type; 74 75 // Compare two AncestorNodes by (node id, index, type). 76 bool operator<(const AncestorNode& other) const; 77 bool operator==(const AncestorNode& other) const; 78 79 struct Hash { 80 size_t operator()(const AncestorNode&) const; 81 }; 82 }; 83 84 // StateMap is responsible for mapping from each graph Node to 85 // * a CondState, where each CondState is a map from predicate to branch (i,e., 86 // what predicates have to hold or not hold). 87 // * a AncestorState, where each AncestorState is a set of switch/merge nodes 88 // that are an ancestor of the node in the graph; 89 // For efficiency, this class interns the CondState (AncestorState), so that 90 // CondState (AncestorState) equality comparisons are simply pointer 91 // comparisons. 92 class StateMap { 93 public: 94 explicit StateMap(Graph* graph); 95 96 // Compare two OutputTensors by (node id, index). 97 struct OutputTensorLess { 98 bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; 99 }; 100 101 // A node in the graph is executed when multiple conditions hold. Keep track 102 // of the predicates that must hold for a node to execute. 103 using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>; 104 105 // Every unique ID is mapped to a CondState. 106 using CondId = const CondState*; 107 108 // Keep track of which switch/merge node's feed into a node's values. 109 using AncestorState = std::set<AncestorNode>; 110 111 // Every unique ID is mapped to a AncestorState. 112 using AncestorId = const AncestorState*; 113 114 // Returns the CondId for a given node. 115 CondId LookupCondId(const Node* node) const; 116 117 // Returns the unique CondId for CondState. 118 CondId GetCondId(const CondState& state); 119 120 // Resets the CondId for a given node. 121 void ResetCondId(const Node* node, CondId id); 122 123 // Returns the AncestorId for a given node. 124 AncestorId LookupAncestorId(const Node* node) const; 125 126 // Returns the unique AncestorId for CondState. 127 AncestorId GetAncestorId(const AncestorState& state); 128 129 // Resets the AncestorId for a given node. 130 void ResetAncestorId(const Node* node, AncestorId id); 131 132 // Marks `node` as dead. 133 void MarkDead(const Node* node); 134 135 // Determine branch execution of CondState. 136 BranchType FindBranchOf(CondId id, OutputTensor predicate) const; 137 138 // Returns textual representation of node's CondState. 139 string CondStateToString(const Node* node) const; 140 string CondStateToString(CondId id) const; 141 142 // Returns textual representation of node's AncestorState. 143 string AncestorStateToString(const Node* node) const; 144 145 // Returns whether the cond state is the dead state. 146 bool IsDead(CondId id) const; 147 148 // Returns whether the cond state is the empty state. 149 bool IsEmpty(CondId id) const; 150 151 private: 152 // Hash for CondState and AncestorState. 153 struct Hash { 154 size_t operator()(const CondState& map) const; 155 size_t operator()(const AncestorState& map) const; 156 }; 157 158 // Set to keep track of unique CondStates. 159 // Pointers to the entries in the unordered set are used as identifiers: 160 // unordered_set guarantees that the pointers remain the same. 161 std::unordered_set<CondState, Hash> condstate_set_; 162 163 // Mapping from Node id to CondId. 164 std::vector<CondId> node_to_condid_map_; 165 166 // Track the CondId for newly inserted nodes. We use a vector to quickly map 167 // from Node id in the original graph to the CondId, but there will be nodes 168 // added to the original graph (such as If nodes) whose CondState needs to be 169 // tracked too. 170 std::unordered_map<int, CondId> added_node_condid_mapping_; 171 172 // AncestorId variants of the CondId members. 173 std::unordered_set<AncestorState, Hash> ancestorstate_set_; 174 std::vector<AncestorId> node_to_ancestorid_map_; 175 std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_; 176 177 // Identifier of the dead flow state. The empty flow state is represented with 178 // a nullptr. 179 CondId dead_id_; 180 }; 181 182 // FunctionalizeCond groups all the state used by functionalizing conditionals 183 // of the given graph together. 184 class FunctionalizeCond { 185 public: 186 // See comment for function `FunctionalizeCond`. 187 static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library, 188 const NodeFilter& node_filter); 189 190 // Build identity node with the same name as the merge that will be replaced 191 // in case the output is fetched/colocated. 192 Status AddIdentityNode(const Node* replacee, Node* if_node, int port); 193 194 // Add a If node to the graph defined by def that will, amongst other, replace 195 // replacee in the graph. 196 xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee, 197 const OutputTensor& predicate); 198 199 // Propagates the state of a newly inserted node. 200 Status PropagateUpdatedState(const Node* replacee); 201 202 // Dump graph with the CondState annotated. 203 void DumpGraphWithCondState(const string& name); 204 205 // Adds `switch_id` to the list of Switch node ids. 206 void AddSwitchId(int switch_id); 207 208 private: 209 FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, 210 const NodeFilter& node_filter); 211 212 // Performs the actual cond functionalization. Iterate over groups of merge 213 // nodes (linked by common predicates & ancestor IDs), from innermost to 214 // outermost, and extract into If nodes. 215 Status FunctionalizeInternal(); 216 217 // Returns the forward flow state propagated along edge `e`. 218 // This may modify state_map_. 219 StateMap::CondId StateAlongEdge(const Edge* e); 220 221 // Determines the CondState and AncestorState of all the nodes in the given 222 // vector where the input is expected in reverse topological order. 223 // This populates the state_map_. 224 Status DetermineStates(std::vector<Node*> rev_topo_order); 225 226 // Determine the CondState for a given node using the incoming edges 227 // to the node. Note: it is expected that this node's CondState is only 228 // determined once its input's CondState is. DetermineCondState(Node * dst)229 Status DetermineCondState(Node* dst) { 230 if (IsMerge(dst)) return DetermineCondStateMerge(dst); 231 return DetermineCondStateNonMerge(dst); 232 } 233 234 // Helper functions for DetermineCondState. 235 Status DetermineCondStateNonMerge(Node* dst); 236 Status DetermineCondStateMerge(Node* dst); 237 238 // Determines the dst node's CondState by joining the src and dst's CondState 239 // where either the dst node is a merge or not. 240 // These may modify state_map_. 241 xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge, 242 StateMap::CondId src, 243 StateMap::CondId dst); 244 xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src, 245 StateMap::CondId dst); 246 247 // Determines which switch/merge nodes are ancestors of this node. 248 Status DetermineAncestorState(Node* dst); 249 250 // Checks if a merge node is redundant and if so removes it from the graph. 251 Status RemoveRedundantMerge(Node* node); 252 253 // Checks if a switch node is redundant and if so removes it from the graph. 254 Status RemoveRedundantSwitch(Node* node); 255 256 // Sorts merge nodes (in reverse topological order) in order of increasing 257 // nesting depth. 258 void SortMergeNodes(std::vector<Node*>* merge_order); 259 260 // Deletes all nodes in/consumers reachable from switch/merge nodes that were 261 // extracted. 262 void DeleteReachableAndDeadNodes(const std::vector<Node*>& merge_order); 263 264 // Member used to unique the CondState to a unique CondId (AncestorState to a 265 // unique AncestorId) and keep track of CondState/CondId 266 // (AncestorState/AncestorId) per Node. 267 StateMap state_map_; 268 269 // Mapping from merge nodes to predicate. 270 std::unordered_map<Node*, OutputTensor> merge_to_predicate_; 271 272 // Mapping from merge nodes to corresponding If node outputs. 273 std::unordered_map<Node*, OutputTensor> merge_to_replacement_; 274 275 FunctionLibraryDefinition* library_; 276 Graph* graph_; 277 278 friend class FunctionalizeCondTest; 279 280 std::vector<int> switch_ids_; 281 282 // Controls which nodes are skipped for functionalization. 283 NodeFilter node_filter_ = {}; 284 }; 285 286 } // namespace functionalize_cond 287 288 } // namespace tensorflow 289 290 #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ 291