# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that match patterns in a tf.Graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import itertools

import six


@six.add_metaclass(abc.ABCMeta)
class Pattern(object):
  """The parent class of all patterns (e.g. OpTypePattern and OneofPattern)."""

  @abc.abstractmethod
  def match(self, op, tensor):
    """Returns the result of matching op/tensor against this pattern.

    Args:
      op: The operation to match against this pattern.
      tensor: The output tensor of `op` consumed by the op matched by the
        parent pattern, or None when this pattern is the root.

    Returns:
      A `MatchResult` on success, or None if `op`/`tensor` do not match.
    """
    raise NotImplementedError('Method "match" not implemented.')


class OpTypePattern(Pattern):
  """A tree pattern that matches TF expressions with certain op types."""

  def __init__(self, op_type, name=None, inputs=None, ordered_inputs=True):
    """Initializes an OpTypePattern.

    Args:
      op_type: string that specifies the allowed types of the root. It can be
        (1) an op type, e.g. 'Conv2D',
        (2) '*', i.e. wildcard, or
        (3) multiple op types separated by '|', e.g., 'Relu|Relu6'.
        We could use regex strings, which might be worthwhile when we have many
        similar TF op types.
      name: Optional string. The name of the pattern that can be looked up in
        MatchResult.
      inputs: Optional list of `Pattern`s or strings that specify the
        patterns for the inputs of a matching op. Strings are wrapped in
        `OpTypePattern`s. If None or empty, this pattern accepts any inputs of
        a matching op.
      ordered_inputs: Defaults to True. If False, will match any op that
        matches a permutation of the inputs.

    Raises:
      ValueError: if more than 8 inputs are provided when ordered_inputs is
        False.
    """
    self._op_type = op_type
    # Split the allowed types once here rather than on every match() call.
    # None stands for the '*' wildcard.
    self._allowed_op_types = (
        None if op_type == '*' else tuple(op_type.split('|')))
    self._name = name
    if inputs is None:
      inputs = []
    # Unordered matching tries every permutation of the inputs (n! orders), so
    # the input count is capped there. Ordered matching is linear and needs no
    # such limit.
    if not ordered_inputs and len(inputs) > 8:
      raise ValueError(
          'At most 8 inputs are allowed when ordered_inputs is False.')
    self._inputs = [
        input_pattern
        if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern)
        for input_pattern in inputs
    ]
    self._ordered_inputs = ordered_inputs

  @property
  def name(self):
    return self._name

  def match(self, op, tensor):
    """Matches `op` and, recursively, its inputs against this pattern.

    Args:
      op: The operation to match; its `type` and `inputs` are inspected.
      tensor: The output of `op` used by the parent pattern's op, or None.

    Returns:
      A `MatchResult` containing this pattern and every sub-pattern bound in
      the successful match, or None if there is no match.
    """
    if (self._allowed_op_types is not None and
        op.type not in self._allowed_op_types):
      return None

    match_result = MatchResult()
    match_result.add(self, op, tensor)

    if not self._inputs:
      # If pattern.inputs is empty, skips the rest and accepts all the inputs.
      return match_result

    if len(op.inputs) != len(self._inputs):
      return None

    if self._ordered_inputs:
      candidate_orders = [self._inputs]
    else:
      # If order doesn't matter for the inputs, accept the first permutation
      # that matches. Permutations are generated lazily so matching stops
      # early.
      candidate_orders = itertools.permutations(self._inputs)

    for input_patterns in candidate_orders:
      inputs_result = self._match_inputs(op, input_patterns)
      if inputs_result is not None:
        # Only the bindings of the successful attempt are merged. Partial
        # matches from failed permutations are discarded.
        match_result.merge_from(inputs_result)
        return match_result
    return None

  def _match_inputs(self, op, input_patterns):
    """Matches `op.inputs` pairwise against `input_patterns`.

    Returns:
      A fresh `MatchResult` with the bindings of all inputs, or None as soon as
      one input fails to match.
    """
    inputs_result = MatchResult()
    for input_tensor, input_pattern in zip(op.inputs, input_patterns):
      input_match_result = input_pattern.match(input_tensor.op, input_tensor)
      if input_match_result is None:
        return None
      inputs_result.merge_from(input_match_result)
    return inputs_result


class OneofPattern(Pattern):
  """Matches one of the given sub-patterns."""

  def __init__(self, sub_patterns):
    """Initializes a OneofPattern.

    Args:
      sub_patterns: Iterable of `Pattern`s, tried in order.
    """
    self._sub_patterns = sub_patterns

  def match(self, op, tensor):
    """Returns the result of the first sub-pattern that matches, else None."""
    for sub_pattern in self._sub_patterns:
      match_result = sub_pattern.match(op, tensor)
      if match_result is not None:
        return match_result
    return None


class MatchResult(object):
  r"""Encapsulates the result of a match done by GraphMatcher.

  MatchResult contains a map from Pattern to the matching op and tensor.
  When the matching op has multiple output tensors, the matching tensor is the
  output tensor used by the matching op of the parent pattern. E.g., when we
  match graph

      -         +
     / \y0   y1/ \
    x    split    z
          |
          y         (nodes are ops; edges are going up)

  against add_pattern defined as

    y1_pattern = OpTypePattern('*')
    z_pattern = OpTypePattern('*')
    add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern])

  the matching op of `y1_pattern` is `split`, and the matching tensor of
  `y1_pattern` is `y1` not `y0`.
  """

  def __init__(self):
    self._pattern_to_op_tensor = {}
    self._name_to_pattern = {}

  def add(self, pattern, op, tensor):
    """Binds `pattern` to (`op`, `tensor`) and registers its name, if any.

    Raises:
      ValueError: if `pattern.name` is already bound in this result.
    """
    self._pattern_to_op_tensor[pattern] = op, tensor
    if pattern.name is not None:
      if pattern.name in self._name_to_pattern:
        raise ValueError(
            'Name %s is already bound to another pattern' % pattern.name)
      self._name_to_pattern[pattern.name] = pattern

  def _to_pattern(self, pattern_or_name):
    """Resolves a `Pattern` or a pattern name to a `Pattern` (or None)."""
    if isinstance(pattern_or_name, Pattern):
      return pattern_or_name

    # six.string_types accepts both `str` and `unicode` names on Python 2.
    if isinstance(pattern_or_name, six.string_types):
      return self._name_to_pattern.get(pattern_or_name)

    raise ValueError('pattern_or_name has type %s. Expect Pattern or str.' %
                     type(pattern_or_name))

  def _get_op_tensor(self, pattern_or_name):
    """Returns the bound (op, tensor) pair, or None if nothing is bound."""
    pattern = self._to_pattern(pattern_or_name)
    if pattern is None:
      return None
    return self._pattern_to_op_tensor.get(pattern)

  def get_op(self, pattern_or_name):
    """Returns the op matched by `pattern_or_name`, or None."""
    op_tensor = self._get_op_tensor(pattern_or_name)
    return op_tensor[0] if op_tensor else None

  def get_tensor(self, pattern_or_name):
    """Returns the tensor matched by `pattern_or_name`, or None."""
    op_tensor = self._get_op_tensor(pattern_or_name)
    return op_tensor[1] if op_tensor else None

  def merge_from(self, other_match_result):
    """Copies all bindings of `other_match_result` into this result.

    Entries from `other_match_result` overwrite existing ones with the same key.
    """
    # pylint: disable=protected-access
    self._pattern_to_op_tensor.update(other_match_result._pattern_to_op_tensor)
    self._name_to_pattern.update(other_match_result._name_to_pattern)
    # pylint: enable=protected-access


class GraphMatcher(object):
  """Checks if a particular subgraph matches a given pattern."""

  def __init__(self, pattern):
    """Initializes a GraphMatcher.

    Args:
      pattern: The `Pattern` against which `GraphMatcher` matches
        subgraphs.
    """
    self._pattern = pattern
    # Result of the most recent match_op() call.
    self._match_result = None

  def _match_pattern(self, pattern, op, tensor):
    """Returns whether an TF expression rooted at `op` matches `pattern`.

    If there is a match, adds to `self._match_result` the matching op and tensor
    with key `pattern`.

    Args:
      pattern: An `Pattern`.
      op: A `tf.Operation` to match against the pattern.
      tensor: the output `tf.Tensor` of `op` that is used by the matching op of
        `pattern`'s parent. Can be None if `pattern` is already the root of the
        pattern tree.

    Returns:
      True if an TF expression rooted at `op` matches `pattern`.
    """
    match_result = pattern.match(op, tensor)
    if match_result is None:
      return False
    self._match_result.merge_from(match_result)
    return True

  def match_op(self, op):
    """Matches `op` against `self._pattern`.

    Args:
      op: `tf.Operation` to match against the pattern.

    Returns:
      Returns a `MatchResult` if `op` matches the pattern; otherwise, returns
      None.
    """
    self._match_result = MatchResult()
    if not self._match_pattern(self._pattern, op, tensor=None):
      return None
    return self._match_result

  def match_ops(self, ops):
    """Matches each operation in `ops` against `self._pattern`.

    Args:
      ops: collection of `tf.Operation` to match against the pattern.

    Yields:
      `MatchResult` for each `tf.Operation` that matches the pattern.
    """
    for op in ops:
      match_result = self.match_op(op)
      if match_result is not None:
        yield match_result

  def match_graph(self, graph):
    """Matches each operation in `graph` against `self._pattern`.

    Args:
      graph: `tf.Graph` containing operations to match.

    Yields:
      `MatchResult` for each `tf.Operation` in `graph` that matches the pattern.
    """
    # Python 3.3.2+ implements `yield from`, but for now:
    for match_result in self.match_ops(graph.get_operations()):
      yield match_result