1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ 17 #define TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ 18 19 #include <functional> 20 #include <unordered_set> 21 #include <vector> 22 23 #include "tensorflow/core/graph/graph.h" 24 #include "tensorflow/core/lib/gtl/array_slice.h" 25 26 namespace tensorflow { 27 28 // Comparator for two nodes. This is used in order to get a stable ording. 29 using NodeComparator = std::function<bool(const Node*, const Node*)>; 30 31 using EdgeFilter = std::function<bool(const Edge&)>; 32 33 // Compares two node based on their ids. 34 struct NodeComparatorID { operatorNodeComparatorID35 bool operator()(const Node* n1, const Node* n2) const { 36 return n1->id() < n2->id(); 37 } 38 }; 39 40 // Compare two nodes based on their names. 41 struct NodeComparatorName { operatorNodeComparatorName42 bool operator()(const Node* n1, const Node* n2) const { 43 return n1->name() < n2->name(); 44 } 45 }; 46 47 // Perform a depth-first-search on g starting at the source node. 48 // If enter is not empty, calls enter(n) before visiting any children of n. 49 // If leave is not empty, calls leave(n) after visiting all children of n. 50 // If stable_comparator is set, a stable ordering of visit is achieved by 51 // sorting a node's neighbors first before visiting them. 52 // If edge_filter is set then ignores edges for which edge_filter returns false. 53 extern void DFS(const Graph& g, const std::function<void(Node*)>& enter, 54 const std::function<void(Node*)>& leave, 55 const NodeComparator& stable_comparator = {}, 56 const EdgeFilter& edge_filter = {}); 57 58 // Perform a depth-first-search on g starting at the 'start' nodes. 59 // If enter is not empty, calls enter(n) before visiting any children of n. 60 // If leave is not empty, calls leave(n) after visiting all children of n. 61 // If stable_comparator is set, a stable ordering of visit is achieved by 62 // sorting a node's neighbors first before visiting them. 63 // If edge_filter is set then ignores edges for which edge_filter returns false. 64 extern void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, 65 const std::function<void(Node*)>& enter, 66 const std::function<void(Node*)>& leave, 67 const NodeComparator& stable_comparator = {}, 68 const EdgeFilter& edge_filter = {}); 69 extern void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, 70 const std::function<void(const Node*)>& enter, 71 const std::function<void(const Node*)>& leave, 72 const NodeComparator& stable_comparator = {}, 73 const EdgeFilter& edge_filter = {}); 74 75 // Perform a reverse depth-first-search on g starting at the sink node. 76 // If enter is not empty, calls enter(n) before visiting any parents of n. 77 // If leave is not empty, calls leave(n) after visiting all parents of n. 78 // If stable_comparator is set, a stable ordering of visit is achieved by 79 // sorting a node's neighbors first before visiting them. 80 extern void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter, 81 const std::function<void(Node*)>& leave, 82 const NodeComparator& stable_comparator = {}); 83 84 // Perform a reverse depth-first-search on g starting at the 'start' nodes. 85 // If enter is not empty, calls enter(n) before visiting any parents of n. 86 // If leave is not empty, calls leave(n) after visiting all parents of n. 87 // If stable_comparator is set, a stable ordering of visit is achieved by 88 // sorting a node's neighbors first before visiting them. 89 extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, 90 const std::function<void(Node*)>& enter, 91 const std::function<void(Node*)>& leave, 92 const NodeComparator& stable_comparator = {}); 93 extern void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, 94 const std::function<void(const Node*)>& enter, 95 const std::function<void(const Node*)>& leave, 96 const NodeComparator& stable_comparator = {}); 97 98 // Stores in *order the post-order numbering of all nodes 99 // in graph found via a depth first search starting at the source node. 100 // 101 // Note that this is equivalent to reverse topological sorting when the 102 // graph does not have cycles. 103 // 104 // If stable_comparator is set, a stable ordering of visit is achieved by 105 // sorting a node's neighbors first before visiting them. 106 // 107 // If edge_filter is set then ignores edges for which edge_filter returns false. 108 // 109 // REQUIRES: order is not NULL. 110 void GetPostOrder(const Graph& g, std::vector<Node*>* order, 111 const NodeComparator& stable_comparator = {}, 112 const EdgeFilter& edge_filter = {}); 113 114 // Stores in *order the reverse post-order numbering of all nodes 115 // If stable_comparator is set, a stable ordering of visit is achieved by 116 // sorting a node's neighbors first before visiting them. 117 // 118 // If edge_filter is set then ignores edges for which edge_filter returns false. 119 void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, 120 const NodeComparator& stable_comparator = {}, 121 const EdgeFilter& edge_filter = {}); 122 123 // Prune nodes in "g" that are not in some path from the source node 124 // to any node in 'nodes'. Returns true if changes were made to the graph. 125 // Does not fix up source and sink edges. 126 bool PruneForReverseReachability(Graph* g, 127 std::unordered_set<const Node*> nodes); 128 129 // Connect all nodes with no incoming edges to source. 130 // Connect all nodes with no outgoing edges to sink. 131 // 132 // Returns true if and only if 'g' is mutated. 133 bool FixupSourceAndSinkEdges(Graph* g); 134 135 } // namespace tensorflow 136 137 #endif // TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ 138