# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for selecting ops in a graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.util import object_identity


def is_differentiable(op):
  """Return True if a gradient function is registered for `op`'s type.

  Args:
    op: a `tf.Operation`.
  Returns:
    True if the gradient registry holds a non-None entry for
    `op.op_def.name`; False otherwise, including when the lookup raises
    `LookupError` because no entry exists.
  """
  try:
    return ops._gradient_registry.lookup(op.op_def.name) is not None  # pylint: disable=protected-access
  except LookupError:
    return False


def is_iterable(obj):
  """Return true if the object is iterable.

  A `tf.Tensor` is always reported as non-iterable, so callers such as
  `make_list_of_t` wrap a single tensor in a list instead of unpacking it.

  Args:
    obj: any Python object.
  Returns:
    True if `iter(obj)` succeeds and `obj` is not a `tf.Tensor`.
  """
  if isinstance(obj, ops.Tensor):
    return False
  try:
    _ = iter(obj)
  except Exception:  # pylint: disable=broad-except
    return False
  return True


def concatenate_unique(la, lb):
  """Add all the elements of `lb` to `la` if they are not there already.

  The elements added to `la` maintain ordering with respect to `lb`.
  `la` is modified in place; elements must be hashable.

  Args:
    la: List of Python objects.
    lb: List of Python objects.
  Returns:
    `la`: The list `la` with missing elements from `lb`.
  """
  la_set = set(la)
  for l in lb:
    if l not in la_set:
      la.append(l)
      la_set.add(l)
  return la


def get_tensors(graph):
  """Get all the tensors produced by the ops in the graph.

  Only op outputs are collected; op inputs are expected to be outputs of
  other ops in the same graph.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts += op.outputs
  return ts


def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the elements in tops are of the given type(s).
      May be a single type or any iterable of types. If None, the types
      (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable, or one of its elements is not of
      one of the `check_types`.
    ValueError: if the graph is not unique, or tops is empty and
      `none_if_empty` is False.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif is_iterable(check_types):
    # isinstance() only accepts a type or a *tuple* of types, so normalize
    # any other iterable (e.g. a list) to a tuple.
    check_types = tuple(check_types)
  else:
    check_types = (check_types,)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(
          ", ".join([str(t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g._graph_key != op.graph._graph_key:  # pylint: disable=protected-access
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g


def check_graphs(*args):
  """Check that all the elements in args belong to the same graph.

  Elements whose `graph` is None are ignored.

  Args:
    *args: a list of objects with an obj.graph property.
  Raises:
    ValueError: if all the elements do not belong to the same graph.
  """
  graph = None
  for i, sgv in enumerate(args):
    if graph is None and sgv.graph is not None:
      graph = sgv.graph
    elif sgv.graph is not None and sgv.graph is not graph:
      raise ValueError(f"args[{i}] does not belong to the same graph as "
                       "other arguments.")


def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`, or if
      `ts` is a `tf.Graph` and `allow_graph` is `False`.
    ValueError: if `check_graph` is `True` and all the tensors do not belong
      to the same graph.
  """
  if isinstance(ts, ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  if not is_iterable(ts):
    ts = [ts]
  # Materialize once: get_unique_graph() below iterates `ts`, which would
  # otherwise exhaust a generator before the final filtering pass.
  ts = list(ts)
  if not ts:
    return []
  if check_graph:
    check_types = None if ignore_ops else ops.Tensor
    get_unique_graph(ts, check_types=check_types)
  return [t for t in ts if isinstance(t, ops.Tensor)]


def get_generating_ops(ts):
  """Return all the generating ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`, in
    the same order as `ts` (duplicates are kept).
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]


def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`, without
    duplicates, in first-seen order.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  # Identity set for O(1) de-duplication; `tops` keeps the output order.
  seen = object_identity.ObjectIdentitySet()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        tops.append(op)
  return tops


def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of `tf.Operation`.

  Args:
    tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
      operation.
    check_graph: if `True` check if all the operations belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ts: if True, silently ignore `tf.Tensor`.
  Returns:
    A newly created list of `tf.Operation`.
  Raises:
    TypeError: if tops cannot be converted to a list of `tf.Operation`, or if
      `tops` is a `tf.Graph` and `allow_graph` is `False`.
    ValueError: if `check_graph` is `True` and all the ops do not belong to
      the same graph.
  """
  if isinstance(tops, ops.Graph):
    if allow_graph:
      return tops.get_operations()
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  if not is_iterable(tops):
    tops = [tops]
  # Materialize once: get_unique_graph() below iterates `tops`, which would
  # otherwise exhaust a generator before the final filtering pass.
  tops = list(tops)
  if not tops:
    return []
  if check_graph:
    check_types = None if ignore_ts else ops.Operation
    get_unique_graph(tops, check_types=check_types)
  return [op for op in tops if isinstance(op, ops.Operation)]


def _get_inputs(op, only_differentiable):
  """Return `op.inputs`, or [] if `only_differentiable` and `op` has no grad."""
  op_inputs = op.inputs
  if only_differentiable:
    return op_inputs if is_differentiable(op) else []
  return op_inputs


def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
      Ignored (treated as False) when `only_differentiable` is True.
    only_differentiable: if True, only traverse ops which are differentiable.
      This includes natively differentiable ops, or ops with custom gradients.
  Returns:
    A Python list of all the `tf.Operation` behind `seed_ops`, without
    duplicates, ordered wave by wave (seeds first when `inclusive`).
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  control_inputs = control_inputs and (not only_differentiable)

  if not is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize so that generators / sets can be emptiness-checked and indexed.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
  seed_ops = object_identity.ObjectIdentitySet(seed_ops)
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = object_identity.ObjectIdentitySet(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # Mirrors `result` for O(1) membership tests; `result` keeps the order.
  visited = set(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in _get_inputs(op, only_differentiable=only_differentiable):
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in new_wave is unvisited, so append all in set order.
    for new_op in new_wave:
      result.append(new_op)
      visited.add(new_op)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result


class UnliftableError(Exception):
  """Raised if a Tensor cannot be lifted from the graph."""

  # Prevent autograph from rewriting this error.
  ag_pass_through = True


def _as_operation(op_or_tensor):
  """Return the producing op of a `tf.Tensor`, or the argument unchanged."""
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def graph_inputs(op):
  """Return the ops feeding `op`: producers of its inputs + control inputs."""
  return [x.op for x in op.inputs] + list(op.control_inputs)


def _path_from(from_op, tensor, sources):
  """Find one path from `from_op` to `tensor`, ignoring `sources`.

  Args:
    from_op: A `tf.Operation` (a `tf.Tensor` is replaced by its producing op).
    tensor: A `tf.Operation` or `tf.Tensor`.
    sources: A list of `tf.Tensor`, or None (treated as empty).

  Returns:
    A python string containing the path, or "??" if none is found.
  """
  if isinstance(from_op, ops.Tensor):
    from_op = from_op.op
  # map_subgraph() tolerates sources=None; don't crash while building its
  # error message.
  sources = sources or ()

  visited_ops = set(x.op for x in sources)
  ops_to_visit = [_as_operation(tensor)]
  # Maps an op to one op that consumes it, so the path can be walked forward
  # from `from_op` back to `tensor` once found.
  some_op_output = {}
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)
    if op == from_op:
      path_op = op
      path = [path_op]
      final_op = _as_operation(tensor)
      while path_op != final_op:
        path_op = some_op_output[path_op]
        path.append(path_op)
      return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(path))
    else:
      for inp in graph_inputs(op):
        if inp not in visited_ops and inp not in sources:
          some_op_output[inp] = op
          ops_to_visit.append(inp)
  return "??"


# TODO(jmenick) - there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Args:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholders upon which init_tensor depends and are not in
    sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  ops_to_visit = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      should_raise = True
    elif op.type == "Placeholder":
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      # Record the placeholder's output tensors as discovered extra sources.
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s"
          % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
    for inp in graph_inputs(op):
      # Record `op` as a consumer of `inp` (op_outputs is a defaultdict).
      op_outputs[inp].add(op)
      # NOTE(review): `inp` is an Operation while sources/extra_sources hold
      # Tensors, so the second test appears never to exclude anything;
      # re-visits are prevented by `visited_ops`. Confirm before simplifying.
      if inp not in visited_ops and inp not in (sources or extra_sources):
        ops_to_visit.append(inp)

  return extra_sources