# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for selecting ops in a graph."""

from tensorflow.python.framework import ops
from tensorflow.python.util import object_identity


def is_differentiable(op):
  """Return True if a gradient function is registered for the type of `op`.

  Args:
    op: a `tf.Operation`.
  Returns:
    True if the gradient registry returns a non-None entry for
    `op.op_def.name`; False if it has no entry (the lookup raises
    `LookupError`).
  """
  try:
    return ops._gradient_registry.lookup(op.op_def.name) is not None  # pylint: disable=protected-access
  except LookupError:
    return False


def is_iterable(obj):
  """Return true if the object is iterable.

  A `tf.Tensor` is always reported as not iterable, so callers that wrap
  non-iterables in a list treat a single tensor as one element.
  """
  if isinstance(obj, ops.Tensor):
    return False
  try:
    _ = iter(obj)
  except Exception:  # pylint: disable=broad-except
    # Deliberately broad: any failure to obtain an iterator means "treat this
    # as a single object".
    return False
  return True


def concatenate_unique(la, lb):
  """Add all the elements of `lb` to `la` if they are not there already.

  The elements added to `la` maintain ordering with respect to `lb`. `la` is
  modified in place. Elements must be hashable, since membership is tracked
  with a `set`.

  Args:
    la: List of Python objects.
    lb: List of Python objects.
  Returns:
    `la`: The list `la` with missing elements from `lb`.
  """
  la_set = set(la)
  for l in lb:
    if l not in la_set:
      la.append(l)
      la_set.add(l)
  return la


def get_tensors(graph):
  """Get all the tensors produced by the ops of `graph`.

  Only op outputs are collected; tensors consumed by ops of the graph are
  assumed to be outputs of some op in the same graph.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`, in `graph.get_operations()` order.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts += op.outputs
  return ts


def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in `tops`.

  Args:
    tops: iterable of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph, which is returned as is.
    check_types: check that the elements in tops are of given type(s). A
      single type is accepted. If None, the types (tf.Operation, tf.Tensor)
      are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable, or contains an element that is not
      one of `check_types`.
    ValueError: if the graph is not unique, or tops is empty and
      `none_if_empty` is False.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif not is_iterable(check_types):
    check_types = (check_types,)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
          t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g._graph_key != op.graph._graph_key:  # pylint: disable=protected-access
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g


def check_graphs(*args):
  """Check that all the elements in args belong to the same graph.

  Elements whose `graph` property is None are ignored.

  Args:
    *args: a list of objects with an obj.graph property.
  Raises:
    ValueError: if all the elements do not belong to the same graph.
  """
  graph = None
  for i, sgv in enumerate(args):
    if graph is None and sgv.graph is not None:
      graph = sgv.graph
    elif sgv.graph is not None and sgv.graph is not graph:
      raise ValueError(f"args[{i}] does not belong to the same graph as "
                       "other arguments.")


def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
      One-shot iterators (e.g. generators) are accepted.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
      if `check_graph` is `True`, if all the ops do not belong to the same
      graph.
  """
  if isinstance(ts, ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  if not is_iterable(ts):
    ts = [ts]
  # Materialize once: the graph check below iterates `ts`, which would
  # otherwise exhaust a generator before the filtering pass sees it.
  ts = list(ts)
  if not ts:
    return []
  if check_graph:
    check_types = None if ignore_ops else ops.Tensor
    get_unique_graph(ts, check_types=check_types)
  return [t for t in ts if isinstance(t, ops.Tensor)]


def get_generating_ops(ts):
  """Return all the generating ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`, one per
    tensor (duplicates are kept).
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]


def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`, without
    duplicates, in first-seen order.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  # Identity set for O(1) de-duplication; `tops` keeps the ordering.
  seen = object_identity.ObjectIdentitySet()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        tops.append(op)
  return tops


def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of `tf.Operation`.

  Args:
    tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
      operation. One-shot iterators (e.g. generators) are accepted.
    check_graph: if `True` check if all the operations belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ts: if True, silently ignore `tf.Tensor`.
  Returns:
    A newly created list of `tf.Operation`.
  Raises:
    TypeError: if tops cannot be converted to a list of `tf.Operation` or,
      if `check_graph` is `True`, if all the ops do not belong to the
      same graph.
  """
  if isinstance(tops, ops.Graph):
    if allow_graph:
      return tops.get_operations()
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  if not is_iterable(tops):
    tops = [tops]
  # Materialize once: the graph check below iterates `tops`, which would
  # otherwise exhaust a generator before the filtering pass sees it.
  tops = list(tops)
  if not tops:
    return []
  if check_graph:
    check_types = None if ignore_ts else ops.Operation
    get_unique_graph(tops, check_types=check_types)
  return [op for op in tops if isinstance(op, ops.Operation)]


def _get_inputs(op, only_differentiable):
  """Return `op.inputs`, or [] if `only_differentiable` and op has no gradient."""
  op_inputs = op.inputs
  if only_differentiable:
    return op_inputs if is_differentiable(op) else []
  else:
    return op_inputs


def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts, or a single operation. If a list of tensors is given
      instead, the seed_ops are set to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None` (or empty), seeds are not
      filtered; if it is `None` the search covers the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops. Seeds are
      not filtered by this function.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
      Ignored when `only_differentiable` is True.
    only_differentiable: if True, only traverse ops which are differentiable.
      This includes natively differentiable ops, or ops with custom gradients.
  Returns:
    A Python list of all the `tf.Operation` behind `seed_ops`: the seeds
    first (when `inclusive`), then each breadth-first wave in turn. The order
    within a wave is unspecified (waves are built as sets).
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  control_inputs = control_inputs and (not only_differentiable)

  if not is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize before peeking at the first element: calling next() on a
  # one-shot iterator would otherwise drop that element from the walk.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []

  if isinstance(seed_ops[0], ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
  seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
  # Note: an empty (falsy) `within_ops` skips seed filtering here, but
  # `is_within` below then rejects every non-seed op.
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = object_identity.ObjectIdentitySet(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # Identity-set mirror of `result` so "already collected" checks are O(1)
  # instead of a linear scan of the list.
  visited = object_identity.ObjectIdentitySet(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in _get_inputs(op, only_differentiable=only_differentiable):
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # `new_wave` only holds ops absent from `visited` (== set(result)), so a
    # plain extend appends exactly what concatenate_unique would.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result


class UnliftableError(Exception):
  """Raised if a Tensor cannot be lifted from the graph."""

  # Prevent autograph from rewriting this error.
  ag_pass_through = True


def _as_operation(op_or_tensor):
  """Return the producing op of a `tf.Tensor`; return anything else unchanged."""
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def graph_inputs(op):
  """Return the ops producing `op`'s data inputs, followed by its control inputs."""
  return [x.op for x in op.inputs] + list(op.control_inputs)


def show_path(from_op, tensors, sources):
  """Find one path from `from_op` to any of `tensors`, ignoring `sources`.

  The search walks backward (through `graph_inputs`) from the ops of
  `tensors` until it reaches `from_op`; ops producing `sources` are never
  entered.

  Args:
    from_op: A `tf.Operation` (a `tf.Tensor` is replaced by its op).
    tensors: A `tf.Operation`, a `tf.Tensor`, or a list thereof.
    sources: A list of `tf.Tensor`, or None (treated as empty).

  Returns:
    A python string containing the path, or "??" if none is found.
  """
  if isinstance(from_op, ops.Tensor):
    from_op = from_op.op

  if not isinstance(tensors, list):
    tensors = [tensors]

  # map_subgraph tolerates a missing `sources`; do the same here so building
  # an error message never fails on its own.
  if sources is None:
    sources = ()

  final_ops = [_as_operation(tensor) for tensor in tensors]

  visited_ops = set(x.op for x in sources)
  ops_to_visit = list(final_ops)
  # Maps each discovered op to the (already visited) op that consumes it, so
  # the path can be rebuilt from `from_op` back to a final op.
  some_op_output = {}
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)
    if op == from_op:
      path_op = op
      path = [path_op]
      while path_op not in final_ops:
        path_op = some_op_output[path_op]
        path.append(path_op)
      return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(path))
    else:
      for inp in graph_inputs(op):
        if inp not in visited_ops and inp not in sources:
          some_op_output[inp] = op
          ops_to_visit.append(inp)
  return "??"


# TODO(jmenick) - there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Args:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholders upon which init_tensor depends and are not in
    sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  ops_to_visit = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      should_raise = True
    elif op.type == "Placeholder":
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s" %
          (repr(init_tensor), repr(op), show_path(op, init_tensor, sources)))
    for inp in graph_inputs(op):
      op_outputs[inp].add(op)
      # NOTE(review): `inp` is an Operation while `sources`/`extra_sources`
      # hold Tensors, so this membership test presumably never matches;
      # stopping at `sources` seems to rely on callers pre-seeding
      # `visited_ops` with the source ops -- confirm before relying on it.
      if inp not in visited_ops and inp not in (sources or extra_sources):
        ops_to_visit.append(inp)

  return extra_sources