1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility functions for the graph_editor. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import re 24from six import iteritems 25from tensorflow.python.framework import ops as tf_ops 26from tensorflow.python.ops import array_ops as tf_array_ops 27 28__all__ = [ 29 "make_list_of_op", 30 "get_tensors", 31 "make_list_of_t", 32 "get_generating_ops", 33 "get_consuming_ops", 34 "ControlOutputs", 35 "placeholder_name", 36 "make_placeholder_from_tensor", 37 "make_placeholder_from_dtype_and_shape", 38] 39 40 41# The graph editor sometimes need to create placeholders, they are named 42# "geph_*". "geph" stands for Graph-Editor PlaceHolder. 43_DEFAULT_PLACEHOLDER_PREFIX = "geph" 44 45 46def concatenate_unique(la, lb): 47 """Add all the elements of `lb` to `la` if they are not there already. 48 49 The elements added to `la` maintain ordering with respect to `lb`. 50 51 Args: 52 la: List of Python objects. 53 lb: List of Python objects. 54 Returns: 55 `la`: The list `la` with missing elements from `lb`. 56 """ 57 la_set = set(la) 58 for l in lb: 59 if l not in la_set: 60 la.append(l) 61 la_set.add(l) 62 return la 63 64 65# TODO(fkp): very generic code, it should be moved in a more generic place. 66class ListView(object): 67 """Immutable list wrapper. 68 69 This class is strongly inspired by the one in tf.Operation. 70 """ 71 72 def __init__(self, list_): 73 if not isinstance(list_, list): 74 raise TypeError("Expected a list, got: {}.".format(type(list_))) 75 self._list = list_ 76 77 def __iter__(self): 78 return iter(self._list) 79 80 def __len__(self): 81 return len(self._list) 82 83 def __bool__(self): 84 return bool(self._list) 85 86 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 87 __nonzero__ = __bool__ 88 89 def __getitem__(self, i): 90 return self._list[i] 91 92 def __add__(self, other): 93 if not isinstance(other, list): 94 other = list(other) 95 return list(self) + other 96 97 98# TODO(fkp): very generic code, it should be moved in a more generic place. 99def is_iterable(obj): 100 """Return true if the object is iterable.""" 101 if isinstance(obj, tf_ops.Tensor): 102 return False 103 try: 104 _ = iter(obj) 105 except Exception: # pylint: disable=broad-except 106 return False 107 return True 108 109 110def flatten_tree(tree, leaves=None): 111 """Flatten a tree into a list. 112 113 Args: 114 tree: iterable or not. If iterable, its elements (child) can also be 115 iterable or not. 116 leaves: list to which the tree leaves are appended (None by default). 117 Returns: 118 A list of all the leaves in the tree. 119 """ 120 if leaves is None: 121 leaves = [] 122 if isinstance(tree, dict): 123 for _, child in iteritems(tree): 124 flatten_tree(child, leaves) 125 elif is_iterable(tree): 126 for child in tree: 127 flatten_tree(child, leaves) 128 else: 129 leaves.append(tree) 130 return leaves 131 132 133def transform_tree(tree, fn, iterable_type=tuple): 134 """Transform all the nodes of a tree. 135 136 Args: 137 tree: iterable or not. If iterable, its elements (child) can also be 138 iterable or not. 139 fn: function to apply to each leaves. 140 iterable_type: type use to construct the resulting tree for unknown 141 iterable, typically `list` or `tuple`. 142 Returns: 143 A tree whose leaves has been transformed by `fn`. 144 The hierarchy of the output tree mimics the one of the input tree. 145 """ 146 if is_iterable(tree): 147 if isinstance(tree, dict): 148 res = tree.__new__(type(tree)) 149 res.__init__( 150 (k, transform_tree(child, fn)) for k, child in iteritems(tree)) 151 return res 152 elif isinstance(tree, tuple): 153 # NamedTuple? 154 if hasattr(tree, "_asdict"): 155 res = tree.__new__(type(tree), **transform_tree(tree._asdict(), fn)) 156 else: 157 res = tree.__new__(type(tree), 158 (transform_tree(child, fn) for child in tree)) 159 return res 160 elif isinstance(tree, collections.Sequence): 161 res = tree.__new__(type(tree)) 162 res.__init__(transform_tree(child, fn) for child in tree) 163 return res 164 else: 165 return iterable_type(transform_tree(child, fn) for child in tree) 166 else: 167 return fn(tree) 168 169 170def check_graphs(*args): 171 """Check that all the element in args belong to the same graph. 172 173 Args: 174 *args: a list of object with a obj.graph property. 175 Raises: 176 ValueError: if all the elements do not belong to the same graph. 177 """ 178 graph = None 179 for i, sgv in enumerate(args): 180 if graph is None and sgv.graph is not None: 181 graph = sgv.graph 182 elif sgv.graph is not None and sgv.graph is not graph: 183 raise ValueError("Argument[{}]: Wrong graph!".format(i)) 184 185 186def get_unique_graph(tops, check_types=None, none_if_empty=False): 187 """Return the unique graph used by the all the elements in tops. 188 189 Args: 190 tops: list of elements to check (usually a list of tf.Operation and/or 191 tf.Tensor). Or a tf.Graph. 192 check_types: check that the element in tops are of given type(s). If None, 193 the types (tf.Operation, tf.Tensor) are used. 194 none_if_empty: don't raise an error if tops is an empty list, just return 195 None. 196 Returns: 197 The unique graph used by all the tops. 198 Raises: 199 TypeError: if tops is not a iterable of tf.Operation. 200 ValueError: if the graph is not unique. 201 """ 202 if isinstance(tops, tf_ops.Graph): 203 return tops 204 if not is_iterable(tops): 205 raise TypeError("{} is not iterable".format(type(tops))) 206 if check_types is None: 207 check_types = (tf_ops.Operation, tf_ops.Tensor) 208 elif not is_iterable(check_types): 209 check_types = (check_types,) 210 g = None 211 for op in tops: 212 if not isinstance(op, check_types): 213 raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str( 214 t) for t in check_types]), type(op))) 215 if g is None: 216 g = op.graph 217 elif g is not op.graph: 218 raise ValueError("Operation {} does not belong to given graph".format(op)) 219 if g is None and not none_if_empty: 220 raise ValueError("Can't find the unique graph of an empty list") 221 return g 222 223 224def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False): 225 """Convert ops to a list of `tf.Operation`. 226 227 Args: 228 ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single 229 operation. 230 check_graph: if `True` check if all the operations belong to the same graph. 231 allow_graph: if `False` a `tf.Graph` cannot be converted. 232 ignore_ts: if True, silently ignore `tf.Tensor`. 233 Returns: 234 A newly created list of `tf.Operation`. 235 Raises: 236 TypeError: if ops cannot be converted to a list of `tf.Operation` or, 237 if `check_graph` is `True`, if all the ops do not belong to the 238 same graph. 239 """ 240 if isinstance(ops, tf_ops.Graph): 241 if allow_graph: 242 return ops.get_operations() 243 else: 244 raise TypeError("allow_graph is False: cannot convert a tf.Graph.") 245 else: 246 if not is_iterable(ops): 247 ops = [ops] 248 if not ops: 249 return [] 250 if check_graph: 251 check_types = None if ignore_ts else tf_ops.Operation 252 get_unique_graph(ops, check_types=check_types) 253 return [op for op in ops if isinstance(op, tf_ops.Operation)] 254 255 256# TODO(fkp): move this function in tf.Graph? 257def get_tensors(graph): 258 """get all the tensors which are input or output of an op in the graph. 259 260 Args: 261 graph: a `tf.Graph`. 262 Returns: 263 A list of `tf.Tensor`. 264 Raises: 265 TypeError: if graph is not a `tf.Graph`. 266 """ 267 if not isinstance(graph, tf_ops.Graph): 268 raise TypeError("Expected a graph, got: {}".format(type(graph))) 269 ts = [] 270 for op in graph.get_operations(): 271 ts += op.outputs 272 return ts 273 274 275def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): 276 """Convert ts to a list of `tf.Tensor`. 277 278 Args: 279 ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor. 280 check_graph: if `True` check if all the tensors belong to the same graph. 281 allow_graph: if `False` a `tf.Graph` cannot be converted. 282 ignore_ops: if `True`, silently ignore `tf.Operation`. 283 Returns: 284 A newly created list of `tf.Tensor`. 285 Raises: 286 TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or, 287 if `check_graph` is `True`, if all the ops do not belong to the same graph. 288 """ 289 if isinstance(ts, tf_ops.Graph): 290 if allow_graph: 291 return get_tensors(ts) 292 else: 293 raise TypeError("allow_graph is False: cannot convert a tf.Graph.") 294 else: 295 if not is_iterable(ts): 296 ts = [ts] 297 if not ts: 298 return [] 299 if check_graph: 300 check_types = None if ignore_ops else tf_ops.Tensor 301 get_unique_graph(ts, check_types=check_types) 302 return [t for t in ts if isinstance(t, tf_ops.Tensor)] 303 304 305def get_generating_ops(ts): 306 """Return all the generating ops of the tensors in `ts`. 307 308 Args: 309 ts: a list of `tf.Tensor` 310 Returns: 311 A list of all the generating `tf.Operation` of the tensors in `ts`. 312 Raises: 313 TypeError: if `ts` cannot be converted to a list of `tf.Tensor`. 314 """ 315 ts = make_list_of_t(ts, allow_graph=False) 316 return [t.op for t in ts] 317 318 319def get_consuming_ops(ts): 320 """Return all the consuming ops of the tensors in ts. 321 322 Args: 323 ts: a list of `tf.Tensor` 324 Returns: 325 A list of all the consuming `tf.Operation` of the tensors in `ts`. 326 Raises: 327 TypeError: if ts cannot be converted to a list of `tf.Tensor`. 328 """ 329 ts = make_list_of_t(ts, allow_graph=False) 330 ops = [] 331 for t in ts: 332 for op in t.consumers(): 333 if op not in ops: 334 ops.append(op) 335 return ops 336 337 338class ControlOutputs(object): 339 """The control outputs topology.""" 340 341 def __init__(self, graph): 342 """Create a dictionary of control-output dependencies. 343 344 Args: 345 graph: a `tf.Graph`. 346 Returns: 347 A dictionary where a key is a `tf.Operation` instance and the 348 corresponding value is a list of all the ops which have the key 349 as one of their control-input dependencies. 350 Raises: 351 TypeError: graph is not a `tf.Graph`. 352 """ 353 if not isinstance(graph, tf_ops.Graph): 354 raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) 355 self._control_outputs = {} 356 self._graph = graph 357 self._version = None 358 self._build() 359 360 def update(self): 361 """Update the control outputs if the graph has changed.""" 362 if self._version != self._graph.version: 363 self._build() 364 return self 365 366 def _build(self): 367 """Build the control outputs dictionary.""" 368 self._control_outputs.clear() 369 ops = self._graph.get_operations() 370 for op in ops: 371 for control_input in op.control_inputs: 372 if control_input not in self._control_outputs: 373 self._control_outputs[control_input] = [] 374 if op not in self._control_outputs[control_input]: 375 self._control_outputs[control_input].append(op) 376 self._version = self._graph.version 377 378 def get_all(self): 379 return self._control_outputs 380 381 def get(self, op): 382 """return the control outputs of op.""" 383 if op in self._control_outputs: 384 return self._control_outputs[op] 385 else: 386 return () 387 388 @property 389 def graph(self): 390 return self._graph 391 392 393def scope_finalize(scope): 394 if scope and scope[-1] != "/": 395 scope += "/" 396 return scope 397 398 399def scope_dirname(scope): 400 slash = scope.rfind("/") 401 if slash == -1: 402 return "" 403 return scope[:slash + 1] 404 405 406def scope_basename(scope): 407 slash = scope.rfind("/") 408 if slash == -1: 409 return scope 410 return scope[slash + 1:] 411 412 413def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): 414 """Create placeholder name for the graph editor. 415 416 Args: 417 t: optional tensor on which the placeholder operation's name will be based 418 on 419 scope: absolute scope with which to prefix the placeholder's name. None 420 means that the scope of t is preserved. "" means the root scope. 421 prefix: placeholder name prefix. 422 Returns: 423 A new placeholder name prefixed by "geph". Note that "geph" stands for 424 Graph Editor PlaceHolder. This convention allows to quickly identify the 425 placeholder generated by the Graph Editor. 426 Raises: 427 TypeError: if t is not None or a tf.Tensor. 428 """ 429 if scope is not None: 430 scope = scope_finalize(scope) 431 if t is not None: 432 if not isinstance(t, tf_ops.Tensor): 433 raise TypeError("Expected a tf.Tenfor, got: {}".format(type(t))) 434 op_dirname = scope_dirname(t.op.name) 435 op_basename = scope_basename(t.op.name) 436 if scope is None: 437 scope = op_dirname 438 439 if op_basename.startswith("{}__".format(prefix)): 440 ph_name = op_basename 441 else: 442 ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) 443 444 return scope + ph_name 445 else: 446 if scope is None: 447 scope = "" 448 return "{}{}".format(scope, prefix) 449 450 451def make_placeholder_from_tensor(t, scope=None, 452 prefix=_DEFAULT_PLACEHOLDER_PREFIX): 453 """Create a `tf.placeholder` for the Graph Editor. 454 455 Note that the correct graph scope must be set by the calling function. 456 457 Args: 458 t: a `tf.Tensor` whose name will be used to create the placeholder 459 (see function placeholder_name). 460 scope: absolute scope within which to create the placeholder. None 461 means that the scope of `t` is preserved. `""` means the root scope. 462 prefix: placeholder name prefix. 463 Returns: 464 A newly created `tf.placeholder`. 465 Raises: 466 TypeError: if `t` is not `None` or a `tf.Tensor`. 467 """ 468 return tf_array_ops.placeholder( 469 dtype=t.dtype, shape=t.get_shape(), 470 name=placeholder_name(t, scope=scope, prefix=prefix)) 471 472 473def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, 474 prefix=_DEFAULT_PLACEHOLDER_PREFIX): 475 """Create a tf.placeholder for the Graph Editor. 476 477 Note that the correct graph scope must be set by the calling function. 478 The placeholder is named using the function placeholder_name (with no 479 tensor argument). 480 481 Args: 482 dtype: the tensor type. 483 shape: the tensor shape (optional). 484 scope: absolute scope within which to create the placeholder. None 485 means that the scope of t is preserved. "" means the root scope. 486 prefix: placeholder name prefix. 487 Returns: 488 A newly created tf.placeholder. 489 """ 490 return tf_array_ops.placeholder( 491 dtype=dtype, shape=shape, 492 name=placeholder_name(scope=scope, prefix=prefix)) 493 494 495_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") 496 497 498def get_predefined_collection_names(): 499 """Return all the predefined collection names.""" 500 return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys) 501 if not _INTERNAL_VARIABLE_RE.match(key)] 502 503 504def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""): 505 """Find corresponding op/tensor in a different graph. 506 507 Args: 508 target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph. 509 dst_graph: The graph in which the corresponding graph element must be found. 510 dst_scope: A scope which is prepended to the name to look for. 511 src_scope: A scope which is removed from the original of `target` name. 512 513 Returns: 514 The corresponding tf.Tensor` or a `tf.Operation`. 515 516 Raises: 517 ValueError: if `src_name` does not start with `src_scope`. 518 TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation` 519 KeyError: If the corresponding graph element cannot be found. 520 """ 521 src_name = target.name 522 if src_scope: 523 src_scope = scope_finalize(src_scope) 524 if not src_name.startswidth(src_scope): 525 raise ValueError("{} does not start with {}".format(src_name, src_scope)) 526 src_name = src_name[len(src_scope):] 527 528 dst_name = src_name 529 if dst_scope: 530 dst_scope = scope_finalize(dst_scope) 531 dst_name = dst_scope + dst_name 532 533 if isinstance(target, tf_ops.Tensor): 534 return dst_graph.get_tensor_by_name(dst_name) 535 if isinstance(target, tf_ops.Operation): 536 return dst_graph.get_operation_by_name(dst_name) 537 raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target)) 538 539 540def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""): 541 """Find corresponding ops/tensors in a different graph. 542 543 `targets` is a Python tree, that is, a nested structure of iterable 544 (list, tupple, dictionary) whose leaves are instances of 545 `tf.Tensor` or `tf.Operation` 546 547 Args: 548 targets: A Python tree containing `tf.Tensor` or `tf.Operation` 549 belonging to the original graph. 550 dst_graph: The graph in which the corresponding graph element must be found. 551 dst_scope: A scope which is prepended to the name to look for. 552 src_scope: A scope which is removed from the original of `top` name. 553 554 Returns: 555 A Python tree containin the corresponding tf.Tensor` or a `tf.Operation`. 556 557 Raises: 558 ValueError: if `src_name` does not start with `src_scope`. 559 TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation` 560 KeyError: If the corresponding graph element cannot be found. 561 """ 562 def func(top): 563 return find_corresponding_elem(top, dst_graph, dst_scope, src_scope) 564 return transform_tree(targets, func) 565