1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_ 18 19 #include <set> 20 #include <string> 21 22 #include "absl/container/flat_hash_set.h" 23 #include "absl/strings/string_view.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/framework/node_def.pb.h" 27 #include "tensorflow/core/graph/graph.h" 28 #include "tensorflow/core/graph/tensor_id.h" 29 #include "tensorflow/core/grappler/graph_view.h" 30 #include "tensorflow/core/grappler/op_types.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace tensorflow { 35 namespace grappler { 36 37 const char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl"; 38 39 // A utility class to simplify the traversal of a GraphDef that, unlike 40 // GraphView, supports updating the graph. Note that you should not modify the 41 // graph separately, because the view will get out of sync. 42 43 class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> { 44 public: MutableGraphView(GraphDef * graph)45 explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) { 46 for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node); 47 for (NodeDef& node : *graph->mutable_node()) AddAndDedupFanouts(&node); 48 } 49 50 // Lookup fanouts/fanins using immutable ports. 51 using GraphViewInternal::GetFanout; 52 const absl::flat_hash_set<InputPort>& GetFanout( 53 const GraphView::OutputPort& port) const; 54 55 using GraphViewInternal::GetFanin; 56 absl::flat_hash_set<OutputPort> GetFanin( 57 const GraphView::InputPort& port) const; 58 59 using GraphViewInternal::GetRegularFanin; 60 const OutputPort GetRegularFanin(const GraphView::InputPort& port) const; 61 62 // Adds a new node to graph and updates the view. Returns a pointer to the 63 // node in graph. 64 NodeDef* AddNode(NodeDef&& node); 65 66 // Adds all nodes from the `subgraph` to the underlying graph and updates the 67 // view. `subgraph` doesn't have to be a valid graph definition on it's own, 68 // it can have edges to the nodes that are not in it, however after adding 69 // it to the underlying graph, final graph must be valid. 70 // 71 // If subgraph function library is not empty, all new functions will be added 72 // to the graph. Functions that appear with the same name in both subgraph and 73 // the graph represented by *this, must have identical function definitions. 74 // 75 // IMPORTANT: All nodes and functions of the given subgraph moved into the 76 // underlying graph, which leaves subgraph in valid but undefined state. 77 Status AddSubgraph(GraphDef&& subgraph); 78 79 // Updates node `node_name` op, device, and attributes. This will clear any 80 // existing attributes. If it is not possible to update the node or if the 81 // node does not exist, an error will be returned and nothing will be modified 82 // in the graph. 83 Status UpdateNode(absl::string_view node_name, absl::string_view op, 84 absl::string_view device, 85 absl::Span<const std::pair<string, AttrValue>> attrs); 86 87 // Updates node `from_node_name` name to `to_node_name`. If `to_node_name` is 88 // in use, node `from_node_name` does not exist, or node `from_node_name` has 89 // fanouts and `update_fanouts` is set to false, an error will be returned and 90 // nothing will be modified in the graph. 91 Status UpdateNodeName(absl::string_view from_node_name, 92 absl::string_view to_node_name, bool update_fanouts); 93 94 // Swap node names `from_node_name` and `to_node_name`. Self loops of one node 95 // are removed by updating the inputs introducing self loops to use the other 96 // node's name. Setting `update_fanouts` to false will exclude other fanouts 97 // from having their inputs updated, but inputs introducing self loops will 98 // always be updated regardless of `update_fanouts. 99 // 100 // Example: 101 // 1. foo(other:3, bar:2, ^bar) 102 // 2. bar(foo:3, other:1, foo:1, ^foo) 103 // 3. other(foo:5, bar:6) 104 // 105 // After calling SwapNodeNames("foo", "bar", false): 106 // 1. bar(other:3, foo:2, ^foo) 107 // 2. foo(bar:3, other:1, bar:1, ^bar) 108 // 3. other(foo:5, bar:6) 109 // 110 // After calling SwapNodeNames("foo", "bar", true): 111 // 1. bar(other:3, foo:2, ^foo) 112 // 2. foo(bar:3, other:1, bar:1, ^bar) 113 // 3. other(bar:5, foo:6) 114 // 115 // If it is not possible to swap node names (i.e. nodes do not exist or Switch 116 // control dependency may be introduced), an error will be returned and 117 // nothing will be modified in the graph. 118 Status SwapNodeNames(absl::string_view from_node_name, 119 absl::string_view to_node_name, bool update_fanouts); 120 121 // Updates all fanouts (input ports fetching output tensors) from 122 // `from_node_name` to the `to_node_name`, including control dependencies. 123 // 124 // Example: We have 3 nodes that use `bar` node output tensors as inputs: 125 // 1. foo1(bar:0, bar:1, other:0) 126 // 2. foo2(bar:1, other:1) 127 // 3. foo3(other:2, ^bar) 128 // 129 // After calling UpdateFanouts(bar, new_bar): 130 // 1. foo1(new_bar:0, new_bar:1, other:0) 131 // 2. foo2(new_bar:1, other:1) 132 // 3. foo3(other:2, ^new_bar) 133 Status UpdateFanouts(absl::string_view from_node_name, 134 absl::string_view to_node_name); 135 136 // Adds regular fanin `fanin` to node `node_name`. If the node or fanin do not 137 // exist in the graph, nothing will be modified in the graph. Otherwise fanin 138 // will be added after existing non control dependency fanins. Control 139 // dependencies will be deduped. To add control dependencies, use 140 // AddControllingFanin. 141 Status AddRegularFanin(absl::string_view node_name, const TensorId& fanin); 142 143 // Adds regular fanin `fanin` to node `node_name` at port `port`. If the node 144 // or fanin do not exist in the graph, nothing will be modified in the graph. 145 // Otherwise fanin will be inserted at port `port`. Control dependencies will 146 // be deduped. To add control dependencies, use AddControllingFanin. 147 // 148 // If the port is not a valid port (less than 0 or greater than the number of 149 // regular fanins), this will result in an error and the node will not be 150 // modified. 151 Status AddRegularFaninByPort(absl::string_view node_name, int port, 152 const TensorId& fanin); 153 154 // Adds control dependency `fanin` to the target node named `node_name`. To 155 // add regular fanins, use AddRegularFanin. 156 // 157 // Case 1: If the fanin is not a Switch node, the control dependency is simply 158 // added to the target node: 159 // 160 // fanin -^> target node. 161 // 162 // Case 2: If the fanin is a Switch node, we cannot anchor a control 163 // dependency on it, because unlike other nodes, only one of its outputs will 164 // be generated when the node is activated. In this case, we try to find an 165 // Identity/IdentityN node in the fanout of the relevant port of the Switch 166 // and add it as a fanin to the target node. If no such Identity/IdentityN 167 // node can be found, a new Identity node will be created. In both cases, we 168 // end up with: 169 // 170 // fanin -> Identity{N} -^> target node. 171 // 172 // If the control dependency being added is redundant (control dependency 173 // already exists or control dependency can be deduped from regular fanins), 174 // this will not result in an error and the node will not be modified. 175 Status AddControllingFanin(absl::string_view node_name, 176 const TensorId& fanin); 177 178 // Removes regular fanin `fanin` from node `node_name`. If the node or fanin 179 // do not exist in the graph, nothing will be modified in the graph. If there 180 // are multiple inputs that match the fanin, all of them will be removed. To 181 // remove controlling fanins, use RemoveControllingFanin. 182 // 183 // If the fanin being removed doesn't exist in the node's inputs, this will 184 // not result in an error and the node will not be modified. 185 Status RemoveRegularFanin(absl::string_view node_name, const TensorId& fanin); 186 187 // Removes regular fanin at port `port` from node `node_name`. If the node 188 // does not exist in the graph, nothing will be modified in the graph. 189 // To remove controlling fanins, use RemoveControllingFanin. 190 // 191 // If the port is not a valid port (less than 0 or greater than the last index 192 // of the regular fanins), this will result in an error and the node will not 193 // be modified. 194 Status RemoveRegularFaninByPort(absl::string_view node_name, int port); 195 196 // Removes control dependency `fanin_node_name` from the target node named 197 // `node_name`. If the node or fanin do not exist in the graph, nothing will 198 // be modified in the graph. To remove regular fanins, use RemoveRegualrFanin. 199 // 200 // If the fanin being removed doesn't exist in the node's inputs, this will 201 // not result in an error and the node will not be modified. 202 Status RemoveControllingFanin(absl::string_view node_name, 203 absl::string_view fanin_node_name); 204 205 // Removes all fanins from node `node_name`. Control dependencies will be 206 // retained if keep_controlling_fanins is true. 207 // 208 // If no fanins are removed, this will not result in an error and the node 209 // will not be modified. 210 Status RemoveAllFanins(absl::string_view node_name, 211 bool keep_controlling_fanins); 212 213 // Replaces all fanins `from_fanin` with `to_fanin` in node `node_name`. If 214 // the fanins or node do not exist, nothing will be modified in the graph. 215 // Control dependencies will be deduped. 216 // 217 // If the fanin being updated doesn't exist in the node's inputs, this will 218 // not result in an error and the node will not be modified. 219 Status UpdateFanin(absl::string_view node_name, const TensorId& from_fanin, 220 const TensorId& to_fanin); 221 222 // Replaces fanin at port `port` in node `node_name` with fanin `fanin`. If 223 // the fanins or node do not exist, nothing will be modified in the graph. 224 // Control dependencies will be deduped. 225 // 226 // If the port is not a valid port (less than 0 or greater than the last index 227 // of the regular fanins), this will result in an error and the node will not 228 // be modified. 229 Status UpdateRegularFaninByPort(absl::string_view node_name, int port, 230 const TensorId& fanin); 231 232 // Swaps fanins at ports `from_port` and `to_port` in node `node_name`. If the 233 // node does not exist, nothing will be modified in the graph. 234 // 235 // If the ports are not a valid port (less than 0 or greater than the last 236 // index of the regular fanins), this will result in an error and the node 237 // will not be modified. 238 Status SwapRegularFaninsByPorts(absl::string_view node_name, int from_port, 239 int to_port); 240 241 // Updates all regular fanins to equivalent controlling fanins. If it is not 242 // possible, an error will be returned and nothing will be modified in the 243 // graph. 244 Status UpdateAllRegularFaninsToControlling(absl::string_view node_name); 245 246 // Deletes nodes from the graph. If a node can't be safely removed, 247 // specifically if a node still has fanouts, an error will be returned. Nodes 248 // that can't be found are ignored. 249 Status DeleteNodes(const absl::flat_hash_set<string>& nodes_to_delete); 250 251 private: 252 // Adds fanouts for fanins of node to graph, while deduping control 253 // dependencies from existing control dependencies and regular fanins. Note, 254 // node inputs will be mutated if control dependencies can be deduped. 255 void AddAndDedupFanouts(NodeDef* node); 256 257 // Finds next output port smaller than fanin.port_id and update. The 258 // max_regular_output_port is only updated if fanin.port_id is the same as the 259 // current max_regular_output_port and if the fanouts set is empty. If there 260 // are no regular outputs, max_regular_output_port will be erased. 261 void UpdateMaxRegularOutputPortForRemovedFanin( 262 const OutputPort& fanin, 263 const absl::flat_hash_set<InputPort>& fanin_fanouts); 264 265 // Updates max regular output port for newly added fanin by checking the 266 // current max and updating if the newly added fanin is of a larger port. 267 void UpdateMaxRegularOutputPortForAddedFanin(const OutputPort& fanin); 268 269 // Updates all fanouts (input ports fetching output tensors) from `from_node` 270 // to the `to_node`, including control dependencies. 271 // 272 // Example: We have 3 nodes that use `bar` node output tensors as inputs: 273 // 1. foo1(bar:0, bar:1, other:0) 274 // 2. foo2(bar:1, other:1) 275 // 3. foo3(other:2, ^bar) 276 // 277 // After calling UpdateFanouts(bar, new_bar): 278 // 1. foo1(new_bar:0, new_bar:1, other:0) 279 // 2. foo2(new_bar:1, other:1) 280 // 3. foo3(other:2, ^new_bar) 281 // 282 // IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the 283 // behavior is undefined. 284 Status UpdateFanoutsInternal(NodeDef* from_node, NodeDef* to_node); 285 286 // Adds fanin to node. If fanin is a control dependency, existing control 287 // dependencies will be checked first before adding. Otherwise fanin will be 288 // added after existing non control dependency inputs. 289 bool AddFaninInternal(NodeDef* node, const OutputPort& fanin); 290 291 // Finds control dependency node to be used based on fanin. If fanin is not a 292 // Switch node, fanin.node is simply returned. Otherwise this will try to find 293 // a candidate Identity node consuming fanin, as the control dependency. If it 294 // is not possible or will introduce a self loop, an error message will be 295 // set. If nullptr is returned with no error 296 // GetOrCreateIdentityConsumingSwitch should be called to generate the new 297 // Identity node. 298 NodeDef* GetControllingFaninToAdd(absl::string_view node_name, 299 const OutputPort& fanin, string* error_msg); 300 301 // Finds a generated Identity node consuming Switch node `fanin.node` at port 302 // `fanin.port_id`. If such a node does not exist, a new Identity node will be 303 // created. 304 NodeDef* GetOrCreateIdentityConsumingSwitch(const OutputPort& fanin); 305 306 // Removes all instances of regular fanin `fanin` from node `node`. 307 bool RemoveRegularFaninInternal(NodeDef* node, const OutputPort& fanin); 308 309 // Removes controlling fanin `fanin_node` from node if such controlling fanin 310 // exists. 311 bool RemoveControllingFaninInternal(NodeDef* node, NodeDef* fanin_node); 312 313 // Checks if nodes to be deleted are missing or have any fanouts that will 314 // remain in the graph. If node is removed in either case, the graph will 315 // enter an invalid state. 316 Status CheckNodesCanBeDeleted( 317 const absl::flat_hash_set<string>& nodes_to_delete); 318 319 // Removes fanins of the deleted node from internal state. Control 320 // dependencies are retained iff keep_controlling_fanins is true. 321 void RemoveFaninsInternal(NodeDef* deleted_node, 322 bool keep_controlling_fanins); 323 324 // Removes fanouts of the deleted node from internal state. 325 void RemoveFanoutsInternal(NodeDef* deleted_node); 326 }; 327 328 } // end namespace grappler 329 } // end namespace tensorflow 330 331 #endif // TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_ 332