1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/inlined_vector.h" 21 #include "absl/strings/string_view.h" 22 #include "absl/types/optional.h" 23 #include "absl/types/span.h" 24 #include "tensorflow/core/graph/tensor_id.h" 25 #include "tensorflow/core/grappler/graph_view.h" 26 27 namespace tensorflow { 28 namespace grappler { 29 30 // GraphTopologyView is a helper class to simplify `node-to-node` connectivity 31 // traversals. Regular `GraphView` simplifies `tensor-to-tensor` traversals: 32 // connections between output tensors and inputs of a consumer nodes. For the 33 // topology view we are focused on nodes connected to nodes, and it's irrelevant 34 // if this connection is formed by one or multiple individual tensors. 35 // 36 // Example: 37 // a = Placeholder(..) 38 // b = Placeholder(..) 39 // c = AddN([a, a, b]) 40 // 41 // GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2] 42 // GraphTopologyView edges: [a -> c, b -> c] 43 // 44 // GraphView is used for exploring single node fanins and fanouts, and 45 // GraphTopologyView is focused on efficient full graph traversals (computing 46 // graph node properties from transitive fanouts, etc...). 47 class GraphTopologyView { 48 public: 49 GraphTopologyView() = default; GraphTopologyView(bool skip_invalid_edges)50 explicit GraphTopologyView(bool skip_invalid_edges) 51 : skip_invalid_edges_(skip_invalid_edges) {} 52 53 // Initialize graph topology view from the graph. It's possible to pass 54 // additional edges that do not exist in a graph, but must be respected when 55 // computing graph topology. Example: Tensorflow runtime allows concurrent 56 // execution of dequeue/enqueue ops from the same queue resource, but we might 57 // want to enforce ordering between them for the purpose of graph analysis. 58 Status InitializeFromGraph(const GraphDef& graph, 59 absl::Span<const GraphView::Edge> ephemeral_edges, 60 bool ignore_control_edges); 61 Status InitializeFromGraph(const GraphDef& graph, 62 absl::Span<const GraphView::Edge> ephemeral_edges); 63 Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges); 64 Status InitializeFromGraph(const GraphDef& graph); 65 is_initialized()66 bool is_initialized() const { return graph_ != nullptr; } num_nodes()67 int num_nodes() const { return num_nodes_; } graph()68 const GraphDef* graph() const { return graph_; } 69 70 // Returns true iff the node exists in the underlying graph. 71 bool HasNode(absl::string_view node_name) const; 72 73 // Finds a node by name or returns `nullptr` if it's not in the graph. 74 const NodeDef* GetNode(absl::string_view node_name) const; 75 // Returns a node corresponding to the given node index. 76 const NodeDef* GetNode(int node_idx) const; 77 78 // Returns a node index for the given node name, if the name exists in the 79 // underlying graph. Otherwise returns empty optional. 80 const absl::optional<int> GetNodeIndex(absl::string_view node_name) const; 81 // Returns a node index for the given node, if the node belongs to the 82 // underlying graph. Otherwise returns empty optional. 83 const absl::optional<int> GetNodeIndex(const NodeDef& node) const; 84 85 // Returns all the node indexes that are in the direct fanin of the given 86 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. 87 const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const; 88 // Returns all the node indexes that are in the direct fanout of the given 89 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. 90 const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const; 91 92 private: 93 // If true, all invalid edges and inputs (srd, dst or input node not found in 94 // a graph) will be skipped, otherwise initialization will fail with error. 95 bool skip_invalid_edges_ = false; 96 97 // WARN: `graph_` must outlive this object and graph nodes must not be 98 // destructed, because node names captured with absl::string_view. 99 const GraphDef* graph_ = nullptr; // do not own 100 int num_nodes_ = 0; 101 std::vector<absl::string_view> index_to_node_name_; 102 absl::flat_hash_map<absl::string_view, int> node_name_to_index_; 103 std::vector<absl::InlinedVector<int, 4>> fanins_; // node_idx->input nodes 104 std::vector<absl::InlinedVector<int, 2>> fanouts_; // node_idx->output nodes 105 106 // We need a valid reference to return from GetFanin/GetFanout if the 107 // `node_idx` argument is outside of the [0, num_nodes_) range. 108 absl::InlinedVector<int, 4> empty_fanin_; 109 absl::InlinedVector<int, 2> empty_fanout_; 110 }; 111 112 } // end namespace grappler 113 } // end namespace tensorflow 114 115 #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ 116