1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/grappler/utils/pattern_utils.h"

#include <algorithm>
#include <numeric>

#include "absl/container/flat_hash_set.h"
21 
22 namespace tensorflow {
23 namespace grappler {
24 namespace utils {
25 
IsCommutativeOp(const string & op)26 const bool IsCommutativeOp(const string& op) {
27   // TODO(intel-tf): Add more ops to this list if needed.
28   std::vector<string> op_list = str_util::Split(op, '|');
29   static const auto* commutative_ops = new absl::flat_hash_set<string>(
30       {"Add", "AddV2", "Mul", "Maximum", "SquaredDifference"});
31   for (const string& op_ : op_list) {
32     if (commutative_ops->contains(op_)) return true;
33   }
34   return false;
35 }
36 
37 // op1 is an op name in the pattern and it could be wildcard `*` or some
38 // registered op in tensorflow and may have multiple ops separated by '|'.
39 // op2 is an op name in the computation graph and
40 // is always one of the registered ops in tensorflow.
IsSame(string op1,string op2)41 bool IsSame(string op1, string op2) {
42   if (op1 == "*") return true;
43 
44   std::vector<string> op1_list = str_util::Split(op1, '|');
45   for (const string& op_1 : op1_list) {
46     if (op_1 == op2) return true;
47   }
48 
49   return false;
50 }
51 
52 // A subgraph pattern syntax implicitly defines a DAG having a single root. We
53 // traverse the syntax DAG in DFS manner. This function finds a match for
54 // current root of the pattern with the current node and recursively matches
55 // children subpatterns with the children of current node.
56 template <>
DoesOpTypePatternMatch(const OpTypePattern & pattern,MutableNodeView * node_view,NodeViewMatch * match)57 bool SubGraphMatcher<MatchingDirection::kFollowInputs>::DoesOpTypePatternMatch(
58     const OpTypePattern& pattern, MutableNodeView* node_view,
59     NodeViewMatch* match) {
60   // Currently no control inputs and outputs are allowed.
61   if (node_view->NumControllingFanins() > 0 ||
62       node_view->NumControlledFanouts() > 0)
63     return false;
64 
65   bool op_type_matched = false;
66   if (pattern.op == "*") {
67     op_type_matched = true;
68   } else {
69     // The op field string of current pattern might express an op among multiple
70     // op types (mutually exclusive) separated by '|'.
71     std::vector<string> op_list = str_util::Split(pattern.op, '|');
72     for (const string& op : op_list) {
73       if (node_view->node()->op() == op) {
74         op_type_matched = true;
75         break;
76       }
77     }
78   }
79   if (op_type_matched) {
80     // If op type matches and current node is visited first time, insert current
81     // node to node_label_to_index_ map with the current label as the key.
82     // Multiple occurances of same label in the pattern syntax indicates that
83     // the same node needs to be visited for each of such occurances. Hence
84     // subsequent visits should find the corresponding label in the map as a key
85     // and the current node should be the value for that key.
86     if (node_label_to_index_.find(pattern.label) ==
87         node_label_to_index_.end()) {
88       node_label_to_index_[pattern.label] = node_view->node_index();
89       // Bookkeeping
90       matched_node_indices_.insert(node_view->node_index());
91       if (pattern.node_status == NodeStatus::kRemove) {
92         remove_node_indices_.insert(node_view->node_index());
93       }
94     } else if (node_label_to_index_[pattern.label] != node_view->node_index()) {
95       return false;  // label constraint could not be satisfied.
96     } else {
97       DCHECK(node_label_to_index_[pattern.label] == node_view->node_index());
98     }
99   } else {
100     return false;
101   }
102   // Current root of the pattern syntax is matched with the current node.
103   match->node_view = node_view;
104 
105   // Go for matching child subpattern.
106   if (!pattern.children.empty()) {
107     // Currently only direction toward inputs is implemented.
108     auto graph_children = node_view->GetRegularFanins();
109     int num_children = graph_children.size();
110     if (num_children != pattern.children.size()) {
111       return false;
112     } else {
113       // A pattern is a graph that we would like to match with a subgraph of
114       // a tensorflow computation graph. We travese both pattern-graph and the
115       // given graph in DFS manner and try to find one-to-one mapping between
116       // the nodes. However, commutative binary ops (e.g., Add, AddV2, Mul
117       // etc.) in the computation graph can have their inputs in different order
118       // than the pattern syntax graph. To allow such input permutation in a
119       // limited manner, we employ a heuristic of looking one level ahead in
120       // both graphs, whether visiting the right child of pattern is likely to
121       // match left child of the given graph. In that case, we simply swap the
122       // left subtree with right subtree of pattern syntax graph and continue
123       // matching children of pattern with the children of given computation
124       // graph. Note, we do not change anything in the computation graph during
125       // pattern matching, only the pattern graph is changed. By looking ahead
126       // one step in case of commutative ops, we keep the time comlexity of
127       // pattern matching linear. Since it is only a heuristic and we look only
128       // one step ahead it is not guranteed that all possible permutations will
129       // be matched. For example, when both the input ops to the commutative op
130       // are same, we cannot anticipate which of the permutation is likely to
131       // match unless we look two level down the graphs.
132       std::vector<int> pattern_child_indices(num_children);
133       std::iota(pattern_child_indices.begin(), pattern_child_indices.end(), 0);
134       string op_name = pattern.op;
135       if (IsCommutativeOp(op_name) && num_children == 2) {
136         MutableNodeView* graph_child0_node_view =
137             graph_view_->GetNode(graph_children[0].node_index());
138         MutableNodeView* graph_child1_node_view =
139             graph_view_->GetNode(graph_children[1].node_index());
140         if ((!IsSame(pattern.children[0].op, graph_child0_node_view->GetOp()) &&
141              IsSame(pattern.children[1].op, graph_child0_node_view->GetOp())) ||
142             (!IsSame(pattern.children[1].op, graph_child1_node_view->GetOp()) &&
143              IsSame(pattern.children[0].op, graph_child1_node_view->GetOp())))
144           std::swap(pattern_child_indices[0], pattern_child_indices[1]);
145       }
146       for (int i = 0; i < num_children; ++i) {
147         auto child_node_index = graph_children[i].node_index();
148         // TODO (mdfaijul): Is it guaranted that GetNode will reuturn non null
149         // pointer.
150         MutableNodeView* child_node_view =
151             graph_view_->GetNode(child_node_index);
152         const OpTypePattern& child_pattern =
153             pattern.children[pattern_child_indices[i]];
154         match->children.push_back(NodeViewMatch());
155         NodeViewMatch* child_match = &(match->children.back());
156         if (!DoesOpTypePatternMatch(child_pattern, child_node_view,
157                                     child_match)) {
158           return false;
159         }
160       }
161     }
162   }
163   return true;
164 }
165 
166 // Current implementation supports pattern maching toward node's inputs only.
167 template <>
GetMatchedNodes(const OpTypePattern & pattern,const std::unordered_set<string> & nodes_to_preserve,MutableNodeView * node_view,std::map<string,int> * matched_nodes_map,std::set<int> * remove_node_indices)168 bool SubGraphMatcher<MatchingDirection::kFollowInputs>::GetMatchedNodes(
169     const OpTypePattern& pattern,
170     const std::unordered_set<string>& nodes_to_preserve,
171     MutableNodeView* node_view, std::map<string, int>* matched_nodes_map,
172     std::set<int>* remove_node_indices) {
173   bool found_match = false;
174   match_.reset(new NodeViewMatch());
175   if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) {
176     if (IsSafeNodesToRemove(nodes_to_preserve)) {
177       found_match = true;
178       *matched_nodes_map = this->node_label_to_index_;
179       *remove_node_indices = this->remove_node_indices_;
180     }
181   } else {
182     found_match = false;
183   }
184 
185   // Clear all bookkeeping data
186   match_->Clear();
187   match_.reset(nullptr);
188   matched_node_indices_.clear();
189   node_label_to_index_.clear();
190   remove_node_indices_.clear();
191 
192   return found_match;
193 }
194 
195 }  // namespace utils
196 }  // namespace grappler
197 }  // namespace tensorflow
198