1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_ 18 19 #include "tensorflow/core/grappler/utils/graph_view.h" 20 21 namespace tensorflow { 22 namespace grappler { 23 namespace utils { 24 25 //------------------------------------------------------------------------------ 26 // A pattern can be defined by the following grammar. Here, op_type is any valid 27 // op name in the TensorFlow. 28 // 29 // leaf_pattern ::= `{` op_type `}` 30 // pattern ::= leaf_pattern | 31 // `{` op_type `,` `{` pattern `,` ... `,` pattern `}` `}` 32 // 33 // (1) For example, the following pattern syntax describes a pattern for 34 // _FusedConv2D (Conv2D + BiasAdd + Relu). Note that "*" means any type of op. 35 // 36 // {"Relu", 37 // { 38 // "BiasAdd", 39 // { 40 // {"Conv2D"}, 41 // {"*"} 42 // } 43 // } 44 // } 45 // 46 // The syntax above has a root ("Relu") and children (inputs), where each child 47 // is a sub-pattern. Graph pattern matcher finds a match for the given pattern 48 // syntax in a graph and returns a set of matched nodes. 49 // 50 // (2) In order to match a DAG with a given root, we extend pattern syntax with 51 // labels. 
// For example, a frequently found pattern in Deep Learning models is a
// residual block like below.
//
//    Placeholder  Const
//          |        |
//    +-----+-----+  |
//    |           |  |
//    |           v  v
//    |          Conv2D   Const
//    |            |        |
//    |            v  v-----+
//    |          BiasAdd
//    |            |
//    v v----------+
//   AddV2
//
// As shown above, it is the same input node (Placeholder) consumed by both
// AddV2 and Conv2D. This constraint can be put as labels in the following
// augmented pattern syntax.
//
//    {"AddV2", "my_add",
//      {
//        {"*", "my_residual_input"},
//        {"BiasAdd", "my_bias_add",
//          {
//            {"Conv2D", "my_conv",
//              {
//                {"*", "my_residual_input"},
//                {"*", "my_filter"}
//              }
//            },
//            {"*", "my_bias"}
//          }
//        }
//      }
//    }
//
// Note that the same label "my_residual_input" is used to tell that it is a
// child of both "AddV2" and "Conv2D". Labels are arbitrary strings to associate
// with the nodes to be matched as well as to uniquely identify those nodes.
//
// (3) The motivation for a grammar based pattern matching in grappler is to
// make it easy to find fusion patterns in the remapper. A subgraph that
// matches a given pattern, however, is not fusable if any of the matched nodes
// that will be removed as a part of fusion has a consumer outside the matched
// subgraph. In order to check for such type of external dependencies, we
// further extend pattern syntax by prospective action (NodeStatus) on the
// matched nodes as shown below. This helps cross checking the nodes to be
// removed with the nodes matched initially.
//
//    {"AddV2", "my_add", NodeStatus::kReplace,
//      {
//        {"*", "my_residual_input", NodeStatus::kRemain},
//        {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
//          {
//            {"Conv2D", "my_conv", NodeStatus::kRemove,
//              {
//                {"*", "my_residual_input", NodeStatus::kRemain},
//                {"*", "my_filter", NodeStatus::kRemain}
//              }
//            },
//            {"*", "my_bias", NodeStatus::kRemain}
//          }
//        }
//      }
//    }
//------------------------------------------------------------------------------

// Pattern matcher recursively matches child subpatterns. The direction
// for children could be toward node's input (fanins) or outputs (fanouts).
enum class MatchingDirection { kFollowInputs, kFollowOutputs };

// Action for each node in the set of matched nodes for a given pattern.
enum class NodeStatus { kRemain, kRemove, kReplace };

// One node of a pattern tree: the op type to match, a label identifying the
// matched node, the prospective action on it, and the sub-patterns for its
// children.
//
// TODO(intel-tf): Support multiple roots by making them children of a single
// virtual root.
struct OpTypePattern {
  // Op type to match; "*" stands for any op type.
  string op;
  // Arbitrary string identifying the matched node. Using the same label in
  // several places of a pattern refers to the same node.
  string label;
  // Prospective action on the matched node.
  NodeStatus node_status;
  // Sub-patterns, one per child of this node.
  std::vector<OpTypePattern> children;

  // Returns a textual rendering of this pattern and, recursively, of all its
  // children in the form "{(op: <op>, label: <label>), {<child>,<child>,}}".
  // Note that `node_status` is not part of the output.
  string DebugString() const {
    string result = "{(op: " + op + ", " + "label: " + label + "), {";
    for (const OpTypePattern& child : children) {
      result += child.DebugString();
      result += ",";
    }
    result += "}}";
    return result;
  }
};

// This is a helpful recursive structure that keeps one-to-one mapping of
// pattern syntax to the matched nodes. User can call DebugString to see what
// has been matched so far and where is the failing point.
147 struct NodeViewMatch { 148 MutableNodeView* node_view = nullptr; 149 std::vector<NodeViewMatch> children; 150 DebugStringNodeViewMatch151 string DebugString() const { 152 string result = "{"; 153 if (node_view == nullptr) { 154 result += "Non-Matched-Node}"; 155 return result; 156 } else { 157 result += node_view->node()->DebugString(); 158 result += ", {"; 159 for (const NodeViewMatch& child : children) { 160 result += child.DebugString() + ","; 161 } 162 result += "}}"; 163 return result; 164 } 165 } 166 ClearNodeViewMatch167 void Clear() { 168 for (NodeViewMatch& child : children) { 169 child.Clear(); // child is an object. 170 } 171 children.clear(); // children is a vector. 172 if (node_view != nullptr) { 173 node_view = nullptr; 174 } 175 } 176 }; 177 178 template <MatchingDirection DIRECTION = MatchingDirection::kFollowInputs> 179 class SubGraphMatcher { 180 public: SubGraphMatcher(MutableGraphView * graph_view)181 SubGraphMatcher(MutableGraphView* graph_view) : graph_view_(graph_view){}; 182 183 // If a given pattern is matched, this function returns true as well as the 184 // matched node and remove node info is populated. 185 bool GetMatchedNodes(const OpTypePattern& pattern, 186 const std::unordered_set<string>& nodes_to_preserve, 187 MutableNodeView* node_view, 188 std::map<string, int>* matched_nodes_map, 189 std::set<int>* remove_node_indices); 190 191 private: 192 MutableGraphView* graph_view_; 193 std::map<string, int> node_label_to_index_; 194 std::set<int> matched_node_indices_; 195 std::set<int> remove_node_indices_; 196 std::unique_ptr<NodeViewMatch> match_ = nullptr; 197 198 bool DoesOpTypePatternMatch(const OpTypePattern& pattern, 199 MutableNodeView* node_view, NodeViewMatch* match); 200 201 // This function should be called after the pattern matcher has found 202 // potential matched nodes (i.e. when DoesOpTypePatternMatch returns "true"). 
203 // It performs a sanity check if the candidate nodes for removal in subgraph 204 // fusion is indeed safe to remove. IsSafeNodesToRemove(const std::unordered_set<string> & nodes_to_preserve)205 bool IsSafeNodesToRemove( 206 const std::unordered_set<string>& nodes_to_preserve) { 207 for (const auto& node_idx : remove_node_indices_) { 208 auto node_view = graph_view_->GetNode(node_idx); 209 // Check if the node to be removed is in the nodes to be preserved. 210 string node_name = node_view->GetName(); 211 if (nodes_to_preserve.count(node_name) > 0) return false; 212 // Traverse all the Regular Fanouts. Fanouts are stored as vector of 213 // vector, std::vector<std::vector<MutableFaninView>>. Note that 214 // a MutableNodeView's fanouts are stored in a nested vector of 215 // MutableFaninView type. 216 auto fanouts_by_ports = node_view->GetRegularFanouts(); 217 for (const auto& fanouts : fanouts_by_ports) { 218 for (const auto& fanout : fanouts) { 219 if (!matched_node_indices_.count(fanout.node_index())) { 220 return false; 221 } 222 } 223 } 224 } 225 return true; 226 } 227 }; 228 229 } // namespace utils 230 } // namespace grappler 231 } // namespace tensorflow 232 233 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_ 234