1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ 18 19 #include <functional> 20 21 #include "tensorflow/core/grappler/graph_topology_view.h" 22 23 namespace tensorflow { 24 namespace grappler { 25 26 enum class TraversalDirection { kFollowInputs, kFollowOutputs }; 27 28 // Encapsulate DFS callbacks that will be called during the graph traversal. 29 // 30 // If non-empty, the `pre_order` and `post_order` functors will be called on 31 // each reachable node (including the `from` nodes) in pre and post order. If 32 // loops are found, the `on_back_edge` functor will be called on the 33 // corresponding back edges. Moreover, the pre and post order will assume that 34 // these back edges will be cut. 35 struct DfsCallbacks { 36 DfsCallbacks() = default; DfsCallbacksDfsCallbacks37 DfsCallbacks(std::function<void(const NodeDef*)> pre, 38 std::function<void(const NodeDef*)> post, 39 std::function<void(const NodeDef*, const NodeDef*)> back_edge) 40 : pre_order(std::move(pre)), 41 post_order(std::move(post)), 42 on_back_edge(std::move(back_edge)) {} 43 PreOrderDfsCallbacks44 static DfsCallbacks PreOrder(std::function<void(const NodeDef*)> pre) { 45 return DfsCallbacks(std::move(pre), nullptr, nullptr); 46 } 47 PostOrderDfsCallbacks48 static DfsCallbacks PostOrder(std::function<void(const NodeDef*)> post) { 49 return DfsCallbacks(nullptr, std::move(post), nullptr); 50 } 51 52 std::function<void(const NodeDef*)> pre_order; 53 std::function<void(const NodeDef*)> post_order; 54 std::function<void(const NodeDef*, const NodeDef*)> on_back_edge; 55 }; 56 57 // Encapsulate DFS predicates for traversing the graph. 58 // 59 // The `enter` predicate decides if traversal should enter the node, and the 60 // `advance` predicate decides if the traversal should follow inputs/outputs 61 // from the node. 62 // 63 // If predicates are empty (default initialized), it's assumed that we can enter 64 // into any node and advance from any node respectively. 65 struct DfsPredicates { 66 DfsPredicates() = default; DfsPredicatesDfsPredicates67 DfsPredicates(std::function<bool(const NodeDef*)> enter, 68 std::function<bool(const NodeDef*)> advance) 69 : enter(std::move(enter)), advance(std::move(advance)) {} 70 EnterDfsPredicates71 static DfsPredicates Enter(std::function<bool(const NodeDef*)> enter) { 72 return DfsPredicates(std::move(enter), nullptr); 73 } 74 AdvanceDfsPredicates75 static DfsPredicates Advance(std::function<bool(const NodeDef*)> advance) { 76 return DfsPredicates(nullptr, std::move(advance)); 77 } 78 79 std::function<bool(const NodeDef*)> enter; 80 std::function<bool(const NodeDef*)> advance; 81 }; 82 83 // Traverse the graph in DFS order in the given direction, starting from the 84 // list of nodes specified in the `from` argument. Use `predicates` to decide if 85 // traversal should enter/advance to/from the graph node. These predicates also 86 // applied to the `from` nodes. Call corresponding callbacks for each visited 87 // node. 88 void DfsTraversal(const GraphTopologyView& graph_view, 89 absl::Span<const NodeDef* const> from, 90 TraversalDirection direction, const DfsPredicates& predicates, 91 const DfsCallbacks& callbacks); 92 93 // Traverse the graph in DFS order in the given direction, starting from the 94 // list of nodes specified in the `from` argument. Call corresponding callbacks 95 // for each visited node. 96 void DfsTraversal(const GraphTopologyView& graph_view, 97 absl::Span<const NodeDef* const> from, 98 TraversalDirection direction, const DfsCallbacks& callbacks); 99 100 } // namespace grappler 101 } // namespace tensorflow 102 103 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ 104