1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/pattern_utils.h"
17
18 #include <algorithm>
19
20 #include "absl/container/flat_hash_set.h"
21
22 namespace tensorflow {
23 namespace grappler {
24 namespace utils {
25
IsCommutativeOp(const string & op)26 const bool IsCommutativeOp(const string& op) {
27 // TODO(intel-tf): Add more ops to this list if needed.
28 std::vector<string> op_list = str_util::Split(op, '|');
29 static const auto* commutative_ops = new absl::flat_hash_set<string>(
30 {"Add", "AddV2", "Mul", "Maximum", "SquaredDifference"});
31 for (const string& op_ : op_list) {
32 if (commutative_ops->contains(op_)) return true;
33 }
34 return false;
35 }
36
37 // op1 is an op name in the pattern and it could be wildcard `*` or some
38 // registered op in tensorflow and may have multiple ops separated by '|'.
39 // op2 is an op name in the computation graph and
40 // is always one of the registered ops in tensorflow.
IsSame(string op1,string op2)41 bool IsSame(string op1, string op2) {
42 if (op1 == "*") return true;
43
44 std::vector<string> op1_list = str_util::Split(op1, '|');
45 for (const string& op_1 : op1_list) {
46 if (op_1 == op2) return true;
47 }
48
49 return false;
50 }
51
52 // A subgraph pattern syntax implicitly defines a DAG having a single root. We
53 // traverse the syntax DAG in DFS manner. This function finds a match for
54 // current root of the pattern with the current node and recursively matches
55 // children subpatterns with the children of current node.
56 template <>
DoesOpTypePatternMatch(const OpTypePattern & pattern,MutableNodeView * node_view,NodeViewMatch * match)57 bool SubGraphMatcher<MatchingDirection::kFollowInputs>::DoesOpTypePatternMatch(
58 const OpTypePattern& pattern, MutableNodeView* node_view,
59 NodeViewMatch* match) {
60 // Currently no control inputs and outputs are allowed.
61 if (node_view->NumControllingFanins() > 0 ||
62 node_view->NumControlledFanouts() > 0)
63 return false;
64
65 bool op_type_matched = false;
66 if (pattern.op == "*") {
67 op_type_matched = true;
68 } else {
69 // The op field string of current pattern might express an op among multiple
70 // op types (mutually exclusive) separated by '|'.
71 std::vector<string> op_list = str_util::Split(pattern.op, '|');
72 for (const string& op : op_list) {
73 if (node_view->node()->op() == op) {
74 op_type_matched = true;
75 break;
76 }
77 }
78 }
79 if (op_type_matched) {
80 // If op type matches and current node is visited first time, insert current
81 // node to node_label_to_index_ map with the current label as the key.
82 // Multiple occurances of same label in the pattern syntax indicates that
83 // the same node needs to be visited for each of such occurances. Hence
84 // subsequent visits should find the corresponding label in the map as a key
85 // and the current node should be the value for that key.
86 if (node_label_to_index_.find(pattern.label) ==
87 node_label_to_index_.end()) {
88 node_label_to_index_[pattern.label] = node_view->node_index();
89 // Bookkeeping
90 matched_node_indices_.insert(node_view->node_index());
91 if (pattern.node_status == NodeStatus::kRemove) {
92 remove_node_indices_.insert(node_view->node_index());
93 }
94 } else if (node_label_to_index_[pattern.label] != node_view->node_index()) {
95 return false; // label constraint could not be satisfied.
96 } else {
97 DCHECK(node_label_to_index_[pattern.label] == node_view->node_index());
98 }
99 } else {
100 return false;
101 }
102 // Current root of the pattern syntax is matched with the current node.
103 match->node_view = node_view;
104
105 // Go for matching child subpattern.
106 if (!pattern.children.empty()) {
107 // Currently only direction toward inputs is implemented.
108 auto graph_children = node_view->GetRegularFanins();
109 int num_children = graph_children.size();
110 if (num_children != pattern.children.size()) {
111 return false;
112 } else {
113 // A pattern is a graph that we would like to match with a subgraph of
114 // a tensorflow computation graph. We travese both pattern-graph and the
115 // given graph in DFS manner and try to find one-to-one mapping between
116 // the nodes. However, commutative binary ops (e.g., Add, AddV2, Mul
117 // etc.) in the computation graph can have their inputs in different order
118 // than the pattern syntax graph. To allow such input permutation in a
119 // limited manner, we employ a heuristic of looking one level ahead in
120 // both graphs, whether visiting the right child of pattern is likely to
121 // match left child of the given graph. In that case, we simply swap the
122 // left subtree with right subtree of pattern syntax graph and continue
123 // matching children of pattern with the children of given computation
124 // graph. Note, we do not change anything in the computation graph during
125 // pattern matching, only the pattern graph is changed. By looking ahead
126 // one step in case of commutative ops, we keep the time comlexity of
127 // pattern matching linear. Since it is only a heuristic and we look only
128 // one step ahead it is not guranteed that all possible permutations will
129 // be matched. For example, when both the input ops to the commutative op
130 // are same, we cannot anticipate which of the permutation is likely to
131 // match unless we look two level down the graphs.
132 std::vector<int> pattern_child_indices(num_children);
133 std::iota(pattern_child_indices.begin(), pattern_child_indices.end(), 0);
134 string op_name = pattern.op;
135 if (IsCommutativeOp(op_name) && num_children == 2) {
136 MutableNodeView* graph_child0_node_view =
137 graph_view_->GetNode(graph_children[0].node_index());
138 MutableNodeView* graph_child1_node_view =
139 graph_view_->GetNode(graph_children[1].node_index());
140 if ((!IsSame(pattern.children[0].op, graph_child0_node_view->GetOp()) &&
141 IsSame(pattern.children[1].op, graph_child0_node_view->GetOp())) ||
142 (!IsSame(pattern.children[1].op, graph_child1_node_view->GetOp()) &&
143 IsSame(pattern.children[0].op, graph_child1_node_view->GetOp())))
144 std::swap(pattern_child_indices[0], pattern_child_indices[1]);
145 }
146 for (int i = 0; i < num_children; ++i) {
147 auto child_node_index = graph_children[i].node_index();
148 // TODO (mdfaijul): Is it guaranted that GetNode will reuturn non null
149 // pointer.
150 MutableNodeView* child_node_view =
151 graph_view_->GetNode(child_node_index);
152 const OpTypePattern& child_pattern =
153 pattern.children[pattern_child_indices[i]];
154 match->children.push_back(NodeViewMatch());
155 NodeViewMatch* child_match = &(match->children.back());
156 if (!DoesOpTypePatternMatch(child_pattern, child_node_view,
157 child_match)) {
158 return false;
159 }
160 }
161 }
162 }
163 return true;
164 }
165
166 // Current implementation supports pattern maching toward node's inputs only.
167 template <>
GetMatchedNodes(const OpTypePattern & pattern,const std::unordered_set<string> & nodes_to_preserve,MutableNodeView * node_view,std::map<string,int> * matched_nodes_map,std::set<int> * remove_node_indices)168 bool SubGraphMatcher<MatchingDirection::kFollowInputs>::GetMatchedNodes(
169 const OpTypePattern& pattern,
170 const std::unordered_set<string>& nodes_to_preserve,
171 MutableNodeView* node_view, std::map<string, int>* matched_nodes_map,
172 std::set<int>* remove_node_indices) {
173 bool found_match = false;
174 match_.reset(new NodeViewMatch());
175 if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) {
176 if (IsSafeNodesToRemove(nodes_to_preserve)) {
177 found_match = true;
178 *matched_nodes_map = this->node_label_to_index_;
179 *remove_node_indices = this->remove_node_indices_;
180 }
181 } else {
182 found_match = false;
183 }
184
185 // Clear all bookkeeping data
186 match_->Clear();
187 match_.reset(nullptr);
188 matched_node_indices_.clear();
189 node_label_to_index_.clear();
190 remove_node_indices_.clear();
191
192 return found_match;
193 }
194
195 } // namespace utils
196 } // namespace grappler
197 } // namespace tensorflow
198