• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
18 
19 #include "tensorflow/core/grappler/utils/graph_view.h"
20 
21 namespace tensorflow {
22 namespace grappler {
23 namespace utils {
24 
25 //------------------------------------------------------------------------------
// A pattern can be defined by the following grammar. Here, op_type is any valid
// op name in TensorFlow.
28 //
29 //    leaf_pattern ::= `{` op_type `}`
30 //    pattern ::= leaf_pattern |
31 //                `{` op_type `,` `{` pattern `,` ... `,` pattern `}` `}`
32 //
33 // (1) For example, the following pattern syntax describes a pattern for
34 // _FusedConv2D (Conv2D + BiasAdd + Relu). Note that "*" means any type of op.
35 //
36 //  {"Relu",
37 //    {
38 //      "BiasAdd",
39 //      {
40 //        {"Conv2D"},
41 //        {"*"}
42 //      }
43 //    }
44 //  }
45 //
46 // The syntax above has a root ("Relu") and children (inputs), where each child
47 // is a sub-pattern. Graph pattern matcher finds a match for the given pattern
48 // syntax in a graph and returns a set of matched nodes.
49 //
50 // (2) In order to match a DAG with a given root, we extend pattern syntax with
51 // labels. For example, a frequently found pattern in Deep Learning models is a
52 // residual block like below.
53 //
54 //    Placeholder  Const
55 //          |        |
56 //    +-----+-----+  |
57 //    |           |  |
58 //    |           v  v
59 //    |          Conv2D   Const
60 //    |            |        |
61 //    |            v  v-----+
62 //    |          BiasAdd
63 //    |            |
64 //    v v----------+
65 //   AddV2
66 //
// As shown above, the same input node (Placeholder) is consumed by both
// AddV2 and Conv2D. This constraint can be expressed with labels in the
// following augmented pattern syntax.
70 //
71 //  {"AddV2", "my_add",
72 //    {
73 //      {"*", "my_residual_input"},
74 //      {"BiasAdd", "my_bias_add",
75 //        {
76 //          {"Conv2D", "my_conv",
77 //            {
78 //              {"*", "my_residual_input"},
79 //              {"*", "my_filter"}
80 //            }
81 //          },
//          {"*", "my_bias"}
83 //        }
84 //      }
85 //    }
86 //  }
87 //
88 // Note that the same label "my_residual_input" is used to tell that it is a
89 // child of both "AddV2" and "Conv2D". Labels are arbitrary strings to associate
90 // with the nodes to be matched as well as to uniquely identify those nodes.
91 //
// (3) The motivation for grammar-based pattern matching in grappler is to
// make it easy to find fusion patterns in the remapper. A subgraph that
// matches a given pattern, however, is not fusible if any of the matched nodes
// that will be removed as part of the fusion has a consumer outside the
// matched subgraph. In order to check for this kind of external dependency, we
// further extend the pattern syntax with a prospective action (NodeStatus) on
// each matched node, as shown below. This helps cross-check the nodes to be
// removed against the nodes matched initially.
100 //
101 //  {"AddV2", "my_add", NodeStatus::kReplace,
102 //    {
103 //      {"*", "my_residual_input", NodeStatus::kRemain},
104 //      {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
105 //        {
106 //          {"Conv2D", "my_conv", NodeStatus::kRemove,
107 //            {
108 //              {"*", "my_residual_input", NodeStatus::kRemain},
109 //              {"*", "my_filter", NodeStatus::Remain}
110 //            }
111 //          },
//          {"*", "my_bias", NodeStatus::kRemain}
113 //        }
114 //      }
115 //    }
116 //  }
117 //------------------------------------------------------------------------------
118 
// Pattern matcher recursively matches child subpatterns. The direction
// for children could be toward node's input (fanins) or outputs (fanouts).
enum class MatchingDirection { kFollowInputs, kFollowOutputs };

// Action for each node in the set of matched nodes for a given pattern.
enum class NodeStatus { kRemain, kRemove, kReplace };

// TODO(intel-tf): Support multiple roots by making them children of a single
// virtual root.
//
// One node of a pattern tree:
//   op          - op type to match ("*" matches any op type).
//   label       - name for the matched node; reusing a label in several places
//                 of a pattern requires those places to match the same node.
//   node_status - prospective action on the matched node.
//   children    - sub-patterns matched against the node's fanins or fanouts,
//                 depending on the matcher's MatchingDirection.
struct OpTypePattern {
  string op;
  string label;
  // Defaults to kRemain so that a default-constructed pattern never carries
  // an indeterminate value. Aggregate initialization that omits this field
  // (e.g. {"*", "my_filter"}) already produced kRemain (the zero enumerator),
  // so existing brace-initialized patterns are unaffected.
  NodeStatus node_status = NodeStatus::kRemain;
  std::vector<OpTypePattern> children;

  // Returns a recursive, human-readable rendering of the pattern, e.g.
  //   {(op: Relu, label: r), {{(op: Conv2D, label: c), {}},}}
  // Every child rendering is followed by ','. node_status is not printed.
  string DebugString() const {
    string result = "{(op: " + op + ", label: " + label + "), {";
    for (const OpTypePattern& child : children) {
      // Append piecewise instead of building a temporary concatenation.
      result += child.DebugString();
      result += ",";
    }
    result += "}}";
    return result;
  }
};
143 
144 // This is a helpful recursive structure that keeps one-to-one mapping of
145 // pattern syntax to the matched nodes. User can call DebugString to see what
146 // has been matched so far and where is the failing point.
147 struct NodeViewMatch {
148   MutableNodeView* node_view = nullptr;
149   std::vector<NodeViewMatch> children;
150 
DebugStringNodeViewMatch151   string DebugString() const {
152     string result = "{";
153     if (node_view == nullptr) {
154       result += "Non-Matched-Node}";
155       return result;
156     } else {
157       result += node_view->node()->DebugString();
158       result += ", {";
159       for (const NodeViewMatch& child : children) {
160         result += child.DebugString() + ",";
161       }
162       result += "}}";
163       return result;
164     }
165   }
166 
ClearNodeViewMatch167   void Clear() {
168     for (NodeViewMatch& child : children) {
169       child.Clear();  // child is an object.
170     }
171     children.clear();  // children is a vector.
172     if (node_view != nullptr) {
173       node_view = nullptr;
174     }
175   }
176 };
177 
178 template <MatchingDirection DIRECTION = MatchingDirection::kFollowInputs>
179 class SubGraphMatcher {
180  public:
SubGraphMatcher(MutableGraphView * graph_view)181   SubGraphMatcher(MutableGraphView* graph_view) : graph_view_(graph_view){};
182 
183   // If a given pattern is matched, this function returns true as well as the
184   // matched node and remove node info is populated.
185   bool GetMatchedNodes(const OpTypePattern& pattern,
186                        const std::unordered_set<string>& nodes_to_preserve,
187                        MutableNodeView* node_view,
188                        std::map<string, int>* matched_nodes_map,
189                        std::set<int>* remove_node_indices);
190 
191  private:
192   MutableGraphView* graph_view_;
193   std::map<string, int> node_label_to_index_;
194   std::set<int> matched_node_indices_;
195   std::set<int> remove_node_indices_;
196   std::unique_ptr<NodeViewMatch> match_ = nullptr;
197 
198   bool DoesOpTypePatternMatch(const OpTypePattern& pattern,
199                               MutableNodeView* node_view, NodeViewMatch* match);
200 
201   // This function should be called after the pattern matcher has found
202   // potential matched nodes (i.e. when DoesOpTypePatternMatch returns "true").
203   // It performs a sanity check if the candidate nodes for removal in subgraph
204   // fusion is indeed safe to remove.
IsSafeNodesToRemove(const std::unordered_set<string> & nodes_to_preserve)205   bool IsSafeNodesToRemove(
206       const std::unordered_set<string>& nodes_to_preserve) {
207     for (const auto& node_idx : remove_node_indices_) {
208       auto node_view = graph_view_->GetNode(node_idx);
209       // Check if the node to be removed is in the nodes to be preserved.
210       string node_name = node_view->GetName();
211       if (nodes_to_preserve.count(node_name) > 0) return false;
212       // Traverse all the Regular Fanouts. Fanouts are stored as vector of
213       // vector, std::vector<std::vector<MutableFaninView>>. Note that
214       // a MutableNodeView's fanouts are stored in a nested vector of
215       // MutableFaninView type.
216       auto fanouts_by_ports = node_view->GetRegularFanouts();
217       for (const auto& fanouts : fanouts_by_ports) {
218         for (const auto& fanout : fanouts) {
219           if (!matched_node_indices_.count(fanout.node_index())) {
220             return false;
221           }
222         }
223       }
224     }
225     return true;
226   }
227 };
228 
229 }  // namespace utils
230 }  // namespace grappler
231 }  // namespace tensorflow
232 
233 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
234