• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
18 
#include <deque>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
25 
26 namespace tensorflow {
27 
28 // Functionalize all the switch-merge nodes of a loop-free graph into If
29 // nodes. That is, attempt to transform every remaining switch and merge nodes
30 // in the graph into If nodes.
31 //
32 // If `node_filter` is defined, then only conditions for whose nodes
33 // `node_filter` returns true are functionalized.
34 //
35 // Preconditions:
36 // a) Same as for `FunctionalizeControlFlow` (see comment there).
37 // b) While loops must have been functionalized before according to
38 //    `node_filter` (e.g., by calling `FunctionalizeWhileLoop` with the same
39 //    filter before calling this function).
40 Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
41                          const NodeFilter& node_filter = {});
42 
43 // Internal functions/classes exposed for testing purposes.
44 namespace functionalize_cond {
45 
// Records which side(s) of a conditional a node lives on: no branch, the then
// branch, the else branch, or both branches (as merge nodes do).
//
// kElseBranch and kThenBranch must remain 0 and 1 respectively: code relies on
// them matching the Switch output port indices. The values chosen for kBoth and
// kNeither carry no special meaning.
enum class BranchType : int {
  kElseBranch = 0,  // Switch output port 0.
  kThenBranch = 1,  // Switch output port 1.
  kBoth = 2,        // Both branches (e.g. merge nodes).
  kNeither = 3,     // Outside any branch.
};
56 
57 // When we keep track of which switch/merge node's feed into a node, we record
58 // 1) predicate for non-dead switch node,
59 // 2) the switch node itself for dead switch node,
60 // 3) the merge node itself for merge node.
61 // Case 1) is an optimization. With this optimization, if there are nodes from
62 // different switch nodes but those switch nodes have the same predicate, the
63 // nodes will still have same AncestorState, and they will be clustered into a
64 // single "If".
65 struct AncestorNode {
66   enum class AncestorNodeType {
67     kPred = 0,
68     kSwitch = 1,
69     kMerge = 2,
70   };
71 
72   OutputTensor output_tensor;
73   AncestorNodeType type;
74 
75   // Compare two AncestorNodes by (node id, index, type).
76   bool operator<(const AncestorNode& other) const;
77   bool operator==(const AncestorNode& other) const;
78 
79   struct Hash {
80     size_t operator()(const AncestorNode&) const;
81   };
82 };
83 
84 // StateMap is responsible for mapping from each graph Node to
85 // * a CondState, where each CondState is a map from predicate to branch (i,e.,
86 //   what predicates have to hold or not hold).
87 // * a AncestorState, where each AncestorState is a set of switch/merge nodes
88 //   that are an ancestor of the node in the graph;
89 // For efficiency, this class interns the CondState (AncestorState), so that
90 // CondState (AncestorState) equality comparisons are simply pointer
91 // comparisons.
92 class StateMap {
93  public:
94   explicit StateMap(Graph* graph);
95 
96   // Compare two OutputTensors by (node id, index).
97   struct OutputTensorLess {
98     bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const;
99   };
100 
101   // A node in the graph is executed when multiple conditions hold. Keep track
102   // of the predicates that must hold for a node to execute.
103   using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>;
104 
105   // Every unique ID is mapped to a CondState.
106   using CondId = const CondState*;
107 
108   // Keep track of which switch/merge node's feed into a node's values.
109   using AncestorState = std::set<AncestorNode>;
110 
111   // Every unique ID is mapped to a AncestorState.
112   using AncestorId = const AncestorState*;
113 
114   // Returns the CondId for a given node.
115   CondId LookupCondId(const Node* node) const;
116 
117   // Returns the unique CondId for CondState.
118   CondId GetCondId(const CondState& state);
119 
120   // Resets the CondId for a given node.
121   void ResetCondId(const Node* node, CondId id);
122 
123   // Returns the AncestorId for a given node.
124   AncestorId LookupAncestorId(const Node* node) const;
125 
126   // Returns the unique AncestorId for CondState.
127   AncestorId GetAncestorId(const AncestorState& state);
128 
129   // Resets the AncestorId for a given node.
130   void ResetAncestorId(const Node* node, AncestorId id);
131 
132   // Marks `node` as dead.
133   void MarkDead(const Node* node);
134 
135   // Determine branch execution of CondState.
136   BranchType FindBranchOf(CondId id, OutputTensor predicate) const;
137 
138   // Returns textual representation of node's CondState.
139   string CondStateToString(const Node* node) const;
140   string CondStateToString(CondId id) const;
141 
142   // Returns textual representation of node's AncestorState.
143   string AncestorStateToString(const Node* node) const;
144 
145   // Returns whether the cond state is the dead state.
146   bool IsDead(CondId id) const;
147 
148   // Returns whether the cond state is the empty state.
149   bool IsEmpty(CondId id) const;
150 
151  private:
152   // Hash for CondState and AncestorState.
153   struct Hash {
154     size_t operator()(const CondState& map) const;
155     size_t operator()(const AncestorState& map) const;
156   };
157 
158   // Set to keep track of unique CondStates.
159   // Pointers to the entries in the unordered set are used as identifiers:
160   // unordered_set guarantees that the pointers remain the same.
161   std::unordered_set<CondState, Hash> condstate_set_;
162 
163   // Mapping from Node id to CondId.
164   std::vector<CondId> node_to_condid_map_;
165 
166   // Track the CondId for newly inserted nodes. We use a vector to quickly map
167   // from Node id in the original graph to the CondId, but there will be nodes
168   // added to the original graph (such as If nodes) whose CondState needs to be
169   // tracked too.
170   std::unordered_map<int, CondId> added_node_condid_mapping_;
171 
172   // AncestorId variants of the CondId members.
173   std::unordered_set<AncestorState, Hash> ancestorstate_set_;
174   std::vector<AncestorId> node_to_ancestorid_map_;
175   std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_;
176 
177   // Identifier of the dead flow state. The empty flow state is represented with
178   // a nullptr.
179   CondId dead_id_;
180 };
181 
182 // FunctionalizeCond groups all the state used by functionalizing conditionals
183 // of the given graph together.
184 class FunctionalizeCond {
185  public:
186   // See comment for function `FunctionalizeCond`.
187   static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library,
188                               const NodeFilter& node_filter);
189 
190   // Build identity node with the same name as the merge that will be replaced
191   // in case the output is fetched/colocated.
192   Status AddIdentityNode(const Node* replacee, Node* if_node, int port);
193 
194   // Add a If node to the graph defined by def that will, amongst other, replace
195   // replacee in the graph.
196   xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee,
197                                  const OutputTensor& predicate);
198 
199   // Propagates the state of a newly inserted node.
200   Status PropagateUpdatedState(const Node* replacee);
201 
202   // Dump graph with the CondState annotated.
203   void DumpGraphWithCondState(const string& name);
204 
205   // Adds `switch_id` to the list of Switch node ids.
206   void AddSwitchId(int switch_id);
207 
208  private:
209   FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
210                     const NodeFilter& node_filter);
211 
212   // Performs the actual cond functionalization. Iterate over groups of merge
213   // nodes (linked by common predicates & ancestor IDs), from innermost to
214   // outermost, and extract into If nodes.
215   Status FunctionalizeInternal();
216 
217   // Returns the forward flow state propagated along edge `e`.
218   // This may modify state_map_.
219   StateMap::CondId StateAlongEdge(const Edge* e);
220 
221   // Determines the CondState and AncestorState of all the nodes in the given
222   // vector where the input is expected in reverse topological order.
223   // This populates the state_map_.
224   Status DetermineStates(std::vector<Node*> rev_topo_order);
225 
226   // Determine the CondState for a given node using the incoming edges
227   // to the node. Note: it is expected that this node's CondState is only
228   // determined once its input's CondState is.
DetermineCondState(Node * dst)229   Status DetermineCondState(Node* dst) {
230     if (IsMerge(dst)) return DetermineCondStateMerge(dst);
231     return DetermineCondStateNonMerge(dst);
232   }
233 
234   // Helper functions for DetermineCondState.
235   Status DetermineCondStateNonMerge(Node* dst);
236   Status DetermineCondStateMerge(Node* dst);
237 
238   // Determines the dst node's CondState by joining the src and dst's CondState
239   // where either the dst node is a merge or not.
240   // These may modify state_map_.
241   xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge,
242                                                       StateMap::CondId src,
243                                                       StateMap::CondId dst);
244   xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
245                                                          StateMap::CondId dst);
246 
247   // Determines which switch/merge nodes are ancestors of this node.
248   Status DetermineAncestorState(Node* dst);
249 
250   // Checks if a merge node is redundant and if so removes it from the graph.
251   Status RemoveRedundantMerge(Node* node);
252 
253   // Checks if a switch node is redundant and if so removes it from the graph.
254   Status RemoveRedundantSwitch(Node* node);
255 
256   // Sorts merge nodes (in reverse topological order) in order of increasing
257   // nesting depth.
258   void SortMergeNodes(std::vector<Node*>* merge_order);
259 
260   // Deletes all nodes in/consumers reachable from switch/merge nodes that were
261   // extracted.
262   void DeleteReachableAndDeadNodes(const std::vector<Node*>& merge_order);
263 
264   // Member used to unique the CondState to a unique CondId (AncestorState to a
265   // unique AncestorId) and keep track of CondState/CondId
266   // (AncestorState/AncestorId) per Node.
267   StateMap state_map_;
268 
269   // Mapping from merge nodes to predicate.
270   std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
271 
272   // Mapping from merge nodes to corresponding If node outputs.
273   std::unordered_map<Node*, OutputTensor> merge_to_replacement_;
274 
275   FunctionLibraryDefinition* library_;
276   Graph* graph_;
277 
278   friend class FunctionalizeCondTest;
279 
280   std::vector<int> switch_ids_;
281 
282   // Controls which nodes are skipped for functionalization.
283   NodeFilter node_filter_ = {};
284 };
285 
286 }  // namespace functionalize_cond
287 
288 }  // namespace tensorflow
289 
290 #endif  // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_
291