• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/pattern_utils.h"
17 
18 namespace tensorflow {
19 namespace grappler {
20 namespace utils {
21 
22 // A subgraph pattern syntax implicitly defines a DAG having a single root. We
23 // traverse the syntax DAG in DFS manner. This function finds a match for
24 // current root of the pattern with the current node and recursively matches
25 // children subpatterns with the children of current node.
26 template <>
DoesOpTypePatternMatch(const OpTypePattern & pattern,MutableNodeView * node_view,NodeViewMatch * match)27 bool SubGraphMatcher<MatchingDirection::kFollowInputs>::DoesOpTypePatternMatch(
28     const OpTypePattern& pattern, MutableNodeView* node_view,
29     NodeViewMatch* match) {
30   // Currently no control inputs and outputs are allowed.
31   if (node_view->NumControllingFanins() > 0 ||
32       node_view->NumControlledFanouts() > 0)
33     return false;
34 
35   bool op_type_matched = false;
36   if (pattern.op == "*") {
37     op_type_matched = true;
38   } else {
39     // The op field string of current pattern might express an op among multiple
40     // op types (mutually exclusive) separated by '|'.
41     std::vector<string> op_list = str_util::Split(pattern.op, '|');
42     for (const string& op : op_list) {
43       if (node_view->node()->op() == op) {
44         op_type_matched = true;
45         break;
46       }
47     }
48   }
49   if (op_type_matched) {
50     // If op type matches and current node is visited first time, insert current
51     // node to node_label_to_index_ map with the current label as the key.
52     // Multiple occurances of same label in the pattern syntax indicates that
53     // the same node needs to be visited for each of such occurances. Hence
54     // subsequent visits should find the corresponding label in the map as a key
55     // and the current node should be the value for that key.
56     if (node_label_to_index_.find(pattern.label) ==
57         node_label_to_index_.end()) {
58       node_label_to_index_[pattern.label] = node_view->node_index();
59       // Bookkeeping
60       matched_node_indices_.insert(node_view->node_index());
61       if (pattern.node_status == NodeStatus::kRemove) {
62         remove_node_indices_.insert(node_view->node_index());
63       }
64     } else if (node_label_to_index_[pattern.label] != node_view->node_index()) {
65       return false;  // label constraint could not be satisfied.
66     } else {
67       DCHECK(node_label_to_index_[pattern.label] == node_view->node_index());
68     }
69   } else {
70     return false;
71   }
72   // Current root of the pattern syntax is matched with the current node.
73   match->node_view = node_view;
74 
75   // Go for matching child subpattern.
76   if (!pattern.children.empty()) {
77     // Currently only direction toward inputs is implemented.
78     auto node_view_children = node_view->GetRegularFanins();
79     if (node_view_children.size() != pattern.children.size()) {
80       return false;
81     } else {
82       for (int i = 0; i < pattern.children.size(); ++i) {
83         auto child_node_index = node_view_children[i].node_index();
84         // TODO (mdfaijul): Is it guaranted that GetNode will reuturn non null
85         // pointer.
86         MutableNodeView* child_node_view =
87             graph_view_->GetNode(child_node_index);
88         const OpTypePattern& child_pattern = pattern.children[i];
89         match->children.push_back(NodeViewMatch());
90         NodeViewMatch* child_match = &(match->children.back());
91         if (!DoesOpTypePatternMatch(child_pattern, child_node_view,
92                                     child_match)) {
93           return false;
94         }
95       }
96     }
97   }
98   return true;
99 }
100 
101 // Current implementation supports pattern maching toward node's inputs only.
102 template <>
GetMatchedNodes(const OpTypePattern & pattern,MutableNodeView * node_view,std::map<string,int> * matched_nodes_map,std::set<int> * remove_node_indices)103 bool SubGraphMatcher<MatchingDirection::kFollowInputs>::GetMatchedNodes(
104     const OpTypePattern& pattern, MutableNodeView* node_view,
105     std::map<string, int>* matched_nodes_map,
106     std::set<int>* remove_node_indices) {
107   bool found_match = false;
108   match_.reset(new NodeViewMatch());
109   if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) {
110     if (!HasRemoveNodeExternalDependents()) {
111       found_match = true;
112       matched_nodes_map->swap(this->node_label_to_index_);
113       remove_node_indices->swap(this->remove_node_indices_);
114     }
115   } else {
116     found_match = false;
117     // Clear all bookkeeping data
118     match_->Clear();
119     match_.reset(nullptr);
120     node_label_to_index_.clear();
121     matched_node_indices_.clear();
122     remove_node_indices_.clear();
123   }
124   return found_match;
125 }
126 
127 }  // namespace utils
128 }  // namespace grappler
129 }  // namespace tensorflow
130