1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ 18 19 #include <unordered_map> 20 #include <unordered_set> 21 #include "absl/container/flat_hash_map.h" 22 #include "absl/container/flat_hash_set.h" 23 #include "absl/hash/hash.h" 24 #include "absl/strings/string_view.h" 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/framework/node_def.pb.h" 27 #include "tensorflow/core/framework/op_def.pb.h" 28 #include "tensorflow/core/graph/tensor_id.h" 29 #include "tensorflow/core/grappler/utils.h" 30 #include "tensorflow/core/lib/gtl/map_util.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace tensorflow { 34 namespace grappler { 35 36 // Map a node/op's input/output port_id to arg_id. 37 // 38 // The port_id refers to the n-th tensor of the node, while the arg_id refers to 39 // the n-th arg of the op. These two can be different if an op's arg is a list 40 // of tensors. 41 // 42 // We return -1 for any invalid port_id (i.e., no corresponding arg_id). 43 int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); 44 int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); 45 46 namespace internal { 47 48 // GraphViewInternal is a helper class to simplify graph traversal. It creates 49 // an immutable view of the nodes and edges represented by a GraphDef protocol 50 // buffer. 51 // 52 // There are two public classes implementing GraphViewInternal: 53 // 54 // - GraphView: constructed from the `const GraphDef` and doesn't allow 55 // to mutate underlying graph via input/output ports lookup functions (ports 56 // have const pointers to nodes). 57 // 58 // - MutableGraphView: constructed from the 'GraphDef` and allows to mutate 59 // the graph via input/output ports lookup functions (ports have non-const 60 // pointers to nodes), and also have couple additional functions to 61 // add/remove/replace nodes in the graph. 62 // 63 // --------------------------- !!! WARNING !!! --------------------------------- 64 // Removing nodes from the graph outside of MutableGraphView will 65 // lead to segfaults! Guaranteed by absl::string_view! 66 // ----------------------------------------------------------------------------- 67 // 68 template <typename GraphDefT, typename NodeDefT> 69 class GraphViewInternal { 70 public: 71 struct Port { PortPort72 Port() : node(nullptr), port_id(0) {} PortPort73 Port(NodeDefT* n, int port) : node(n), port_id(port) {} 74 75 bool operator==(const Port& other) const { 76 return node == other.node && port_id == other.port_id; 77 } 78 79 template <typename H> AbslHashValuePort80 friend H AbslHashValue(H h, const Port& p) { 81 return H::combine(std::move(h), p.node, p.port_id); 82 } 83 84 NodeDefT* node; 85 int port_id; 86 }; 87 88 struct InputPort : public Port { 89 using Port::Port; 90 }; 91 92 struct OutputPort : public Port { 93 using Port::Port; 94 }; 95 96 struct Edge { EdgeEdge97 Edge(OutputPort s, InputPort d) : src(s), dst(d) {} 98 99 bool operator==(const Edge& other) const { 100 return src == other.src && dst == other.dst; 101 } 102 103 template <typename H> AbslHashValueEdge104 friend H AbslHashValue(H h, const Edge& e) { 105 return H::combine(std::move(h), e.src, e.dst); 106 } 107 108 OutputPort src; 109 InputPort dst; 110 }; 111 graph()112 GraphDefT* graph() const { return graph_; } 113 114 // Finds a node by name or return `nullptr` if it's not in the graph view. GetNode(absl::string_view node_name)115 NodeDefT* GetNode(absl::string_view node_name) const { 116 return gtl::FindWithDefault(nodes_, node_name, nullptr); 117 } 118 119 // Checks if a node by name is in the graph view. HasNode(absl::string_view node_name)120 bool HasNode(absl::string_view node_name) const { 121 return GetNode(node_name) != nullptr; 122 } 123 124 // Gets the specified input port. Note that the special '-1' port_id can be 125 // used to access the controlling nodes (i.e. the nodes connected to node_name 126 // through an incoming control dependency). GetInputPort(absl::string_view node_name,int port_id)127 InputPort GetInputPort(absl::string_view node_name, int port_id) const { 128 return InputPort(GetNode(node_name), port_id); 129 } 130 131 // Gets the specified output port. Note that the special '-1' port_id can be 132 // used to access the controlled nodes (i.e. the nodes connected to node_name 133 // through an outgoing control dependency). GetOutputPort(absl::string_view node_name,int port_id)134 OutputPort GetOutputPort(absl::string_view node_name, int port_id) const { 135 return OutputPort(GetNode(node_name), port_id); 136 } 137 138 // Gets the input port(s) in the immediate fanout of an output port. GetFanout(const OutputPort & port)139 const absl::flat_hash_set<InputPort>& GetFanout( 140 const OutputPort& port) const { 141 return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_); 142 } 143 144 // Gets the output port(s) in the immediate fanin of an input port. GetFanin(const InputPort & port)145 absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const { 146 if (port.port_id >= 0) { 147 OutputPort regular_fanin = GetRegularFanin(port); 148 if (regular_fanin.node == nullptr) { 149 return {}; 150 } 151 return {regular_fanin}; 152 } 153 154 // Collect fanin for the control input. 155 absl::flat_hash_set<OutputPort> result; 156 const int first_control_port = 157 gtl::FindWithDefault(max_regular_input_port_, port.node, -1) + 1; 158 for (int i = first_control_port; i < port.node->input_size(); ++i) { 159 TensorId tensor_id = ParseTensorName(port.node->input(i)); 160 161 auto it = nodes_.find(tensor_id.node()); 162 if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); 163 } 164 return result; 165 } 166 167 // Special case: regular (i.e. non-control) input ports can only have one 168 // fanin. If port.port_id is out of range or is a control dependency, then an 169 // empty OutputPort is returned. GetRegularFanin(const InputPort & port)170 const OutputPort GetRegularFanin(const InputPort& port) const { 171 if (port.port_id < 0 || 172 port.port_id > 173 gtl::FindWithDefault(max_regular_input_port_, port.node, -1)) { 174 return OutputPort(); 175 } 176 177 TensorId tensor_id = ParseTensorName(port.node->input(port.port_id)); 178 return GetOutputPort(tensor_id.node(), tensor_id.index()); 179 } 180 181 // Checks if a tensor id is a fanin of the node. HasFanin(const NodeDefT & node,const TensorId & fanin)182 bool HasFanin(const NodeDefT& node, const TensorId& fanin) const { 183 int end = node.input_size(); 184 if (end == 0 || fanin.index() < -1) { 185 return false; 186 } 187 188 const int num_regular_fanins = 189 gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1; 190 int start = 0; 191 if (fanin.index() > -1) { 192 end = num_regular_fanins; 193 } else { 194 start = num_regular_fanins; 195 } 196 for (int i = start; i < end; ++i) { 197 if (ParseTensorName(node.input(i)) == fanin) { 198 return true; 199 } 200 } 201 return false; 202 } 203 204 // Gets all the input ports in the immediate fanout of a node. Include the 205 // controlled nodes iff include_controlled_nodes is true. GetFanouts(const NodeDefT & node,bool include_controlled_nodes)206 absl::flat_hash_set<InputPort> GetFanouts( 207 const NodeDefT& node, bool include_controlled_nodes) const { 208 absl::flat_hash_set<InputPort> result; 209 210 OutputPort port; 211 port.node = const_cast<NodeDefT*>(&node); 212 const int first_port_id = include_controlled_nodes ? -1 : 0; 213 const int last_port_id = 214 gtl::FindWithDefault(max_regular_output_port_, &node, -1); 215 216 for (int i = first_port_id; i <= last_port_id; ++i) { 217 port.port_id = i; 218 auto it = fanouts_.find(port); 219 if (it != fanouts_.end()) { 220 result.insert(it->second.begin(), it->second.end()); 221 } 222 } 223 return result; 224 } 225 226 // Gets all the output ports in the immediate fanin of a node. Include the 227 // controlling nodes iff include_controlling_nodes is true. GetFanins(const NodeDefT & node,bool include_controlling_nodes)228 absl::flat_hash_set<OutputPort> GetFanins( 229 const NodeDefT& node, bool include_controlling_nodes) const { 230 absl::flat_hash_set<OutputPort> result; 231 const int max_input_port = 232 include_controlling_nodes 233 ? node.input_size() - 1 234 : gtl::FindWithDefault(max_regular_input_port_, &node, -1); 235 for (int i = 0; i <= max_input_port; ++i) { 236 TensorId tensor_id = ParseTensorName(node.input(i)); 237 238 auto it = nodes_.find(tensor_id.node()); 239 if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); 240 } 241 return result; 242 } 243 244 // Gets the number of ports in the immediate fanin of a node. Count the 245 // controlling nodes iff include_controlling_nodes is true. NumFanins(const NodeDefT & node,bool include_controlling_nodes)246 int NumFanins(const NodeDefT& node, bool include_controlling_nodes) const { 247 if (include_controlling_nodes) { 248 return node.input_size(); 249 } 250 return gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1; 251 } 252 253 // Gets the number of ports in the immediate fanout of a node. Count the 254 // controlled nodes iff include_controlled_nodes is true. NumFanouts(const NodeDefT & node,bool include_controlled_nodes)255 int NumFanouts(const NodeDefT& node, bool include_controlled_nodes) const { 256 int count = 0; 257 258 OutputPort port; 259 port.node = const_cast<NodeDefT*>(&node); 260 const int first_port_id = include_controlled_nodes ? -1 : 0; 261 const int last_port_id = 262 gtl::FindWithDefault(max_regular_output_port_, &node, -1); 263 264 for (int i = first_port_id; i <= last_port_id; ++i) { 265 port.port_id = i; 266 auto it = fanouts_.find(port); 267 if (it != fanouts_.end()) count += it->second.size(); 268 } 269 270 return count; 271 } 272 273 // Gets all the edges in the immediate fanout of a node. Include the 274 // controlled edges iff include_controlled_edges is true. GetFanoutEdges(const NodeDefT & node,bool include_controlled_edges)275 absl::flat_hash_set<Edge> GetFanoutEdges( 276 const NodeDefT& node, bool include_controlled_edges) const { 277 absl::flat_hash_set<Edge> result; 278 279 OutputPort port; 280 port.node = const_cast<NodeDefT*>(&node); 281 const int first_port_id = include_controlled_edges ? -1 : 0; 282 const int last_port_id = 283 gtl::FindWithDefault(max_regular_output_port_, &node, -1); 284 285 for (int i = first_port_id; i <= last_port_id; ++i) { 286 port.port_id = i; 287 auto it = fanouts_.find(port); 288 if (it != fanouts_.end()) { 289 for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { 290 result.emplace(/*src=*/port, /*dst=*/*itr); 291 } 292 } 293 } 294 return result; 295 } 296 297 // Gets all the edges in the immediate fanin of a node. Include the 298 // controlling edges iff include_controlling_edges is true. GetFaninEdges(const NodeDefT & node,bool include_controlling_edges)299 absl::flat_hash_set<Edge> GetFaninEdges( 300 const NodeDefT& node, bool include_controlling_edges) const { 301 absl::flat_hash_set<Edge> result; 302 const int max_input_port = 303 include_controlling_edges 304 ? node.input_size() - 1 305 : gtl::FindWithDefault(max_regular_input_port_, &node, -1); 306 for (int i = 0; i <= max_input_port; ++i) { 307 TensorId tensor_id = ParseTensorName(node.input(i)); 308 309 auto it = nodes_.find(tensor_id.node()); 310 if (it != nodes_.end()) { 311 result.emplace(/*src=*/OutputPort(it->second, tensor_id.index()), 312 /*dst=*/InputPort(const_cast<NodeDefT*>(&node), i)); 313 } 314 } 315 return result; 316 } 317 318 protected: GraphViewInternal(GraphDefT * graph)319 explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} 320 AddUniqueNode(NodeDefT * node)321 Status AddUniqueNode(NodeDefT* node) { 322 auto inserted = nodes_.emplace(node->name(), node); 323 return inserted.second 324 ? Status::OK() 325 : errors::InvalidArgument("Non unique node name detected: ", 326 node->name()); 327 } 328 329 // TODO(ezhulenev): Remove this function. AddUniqueNodeOrDie(NodeDefT * node)330 void AddUniqueNodeOrDie(NodeDefT* node) { 331 Status st = AddUniqueNode(node); 332 CHECK(st.ok()) << st.error_message(); 333 } 334 335 // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins 336 // exist, and all regular fanins come before controlling fanins. AddFanouts(NodeDefT * node)337 void AddFanouts(NodeDefT* node) { 338 int max_input_port = -1; 339 for (int i = 0; i < node->input_size(); ++i) { 340 TensorId tensor_id = ParseTensorName(node->input(i)); 341 OutputPort output(nodes_[tensor_id.node()], tensor_id.index()); 342 343 if (output.port_id < 0) { 344 fanouts_[output].emplace(node, -1); 345 } else { 346 max_input_port = i; 347 max_regular_output_port_[output.node] = 348 std::max(max_regular_output_port_[output.node], output.port_id); 349 fanouts_[output].emplace(node, i); 350 } 351 } 352 if (max_input_port > -1) { 353 max_regular_input_port_[node] = max_input_port; 354 } 355 } 356 357 // Access to the mutable internal state for MutableGraphView. nodes()358 absl::flat_hash_map<absl::string_view, NodeDefT*>& nodes() { return nodes_; } 359 fanouts()360 absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>& fanouts() { 361 return fanouts_; 362 } 363 max_regular_input_port()364 absl::flat_hash_map<const NodeDefT*, int>& max_regular_input_port() { 365 return max_regular_input_port_; 366 } 367 max_regular_output_port()368 absl::flat_hash_map<const NodeDefT*, int>& max_regular_output_port() { 369 return max_regular_output_port_; 370 } 371 372 private: 373 GraphDefT* graph_; // must outlive the graph view 374 375 // A mapping from the node name to the node itself. 376 absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_; 377 378 // A mapping from the output port to all inputs that read from it. 379 absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_; 380 381 // Keep a maximum index of input tensors of the node. 382 absl::flat_hash_map<const NodeDefT*, int> max_regular_input_port_; 383 384 // Keep a maximum index of tensor fetched from the node. It doesn't guarantee 385 // that all tensors in the [0, max_regular_output_port] range are actually 386 // fetched by other nodes. 387 absl::flat_hash_map<const NodeDefT*, int> max_regular_output_port_; 388 389 // If the node has no fanouts at given output port (output tensor consumers) 390 // we return a reference to this set from `GetFanout` (we can't construct new 391 // empty set every time, because we need a non-dangling reference). 392 absl::flat_hash_set<InputPort> fanout_not_found_value_; 393 }; 394 395 } // namespace internal 396 397 // Immutable GraphView that keeps the constness of the GraphDef. If you need to 398 // mutate the graph or the nodes via the graph view lookup functions, see 399 // MutableGraphView. 400 class GraphView 401 : public internal::GraphViewInternal<const GraphDef, const NodeDef> { 402 public: GraphView(const GraphDef * graph)403 explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) { 404 for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node); 405 for (const NodeDef& node : graph->node()) AddFanouts(&node); 406 } 407 }; 408 409 // Returns true if node has one (or zero) fanout nodes at given output port. 410 bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node, 411 int port = 0); 412 413 // Returns true if node has at least one fanout node at given output port. 414 bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port = 0); 415 // Returns true if the node has at least one input control dependency. 416 bool HasControlFanin(const GraphView& graph_view, const NodeDef* node); 417 // Returns true if the node has at least one output control dependency. 418 bool HasControlFanout(const GraphView& graph_view, const NodeDef* node); 419 // Returns true if the node has at least one input or output control dependency. 420 bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node); 421 422 } // end namespace grappler 423 } // end namespace tensorflow 424 425 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ 426