# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SubGraphView: a subgraph view on an existing tf.Graph.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import six
from six import iteritems
from six import StringIO

from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops

__all__ = [
    "SubGraphView",
    "make_view",
    "make_view_from_scope",
]


def _finalize_index(index_or_t, ts):
  """Returns index as is or return index of tensor in `ts`."""
  if isinstance(index_or_t, six.integer_types):
    return index_or_t
  else:
    return ts.index(index_or_t)


def _finalize_indices(list_of_index_or_t, ts):
  """Returns index in `indices` as is or replace with tensor's index."""
  return [_finalize_index(index_or_t, ts) for index_or_t in list_of_index_or_t]


def _check_within_range(mapping, n, repetition):
  """Check if the mapping is valid.

  Args:
    mapping: an iterable of integer.
    n: define the input domain as [0, n-1]. Note that the mapping can be
      under-complete, that is, it can only contain a subset of the integers on
      [0, n-1].
    repetition: if True repetition are allowed (the function is surjective)
      otherwise repetition are not allowed (the function is injective).
  Raises:
    ValueError: if the mapping is out of range or if repetition is False and
      the mapping has some repetition.
  """
  for i in mapping:
    if not 0 <= i < n:
      raise ValueError("Out of [0, {}[ range: {}".format(n, i))
  if not repetition and len(set(mapping)) != len(mapping):
    raise ValueError("Found repetition in mapping: {}".format(mapping))


class SubGraphView(object):
  """A subgraph view on an existing `tf.Graph`.

  An instance of this class is a subgraph view on an existing `tf.Graph`.
  "subgraph" means that it can represent part of the whole `tf.Graph`.
  "view" means that it only provides a passive observation and does not act
  on the `tf.Graph`. Note that in this documentation, the term "subgraph" is
  often used as substitute to "subgraph view".

  A subgraph contains:

  * a list of input tensors, accessible via the `inputs` property.
  * a list of output tensors, accessible via the `outputs` property.
  * and the operations in between, accessible via the "ops" property.

  A subgraph can be seen as a function F(i0, i1, ...) -> o0, o1, ... It is a
  function which takes as input some input tensors and returns as output some
  output tensors. The computation that the function performs is encoded in the
  operations of the subgraph.

  The tensors (input or output) can be of two kinds:

  - connected: a connected tensor connects to at least one operation contained
    in the subgraph. One example is a subgraph representing a single operation
    and its inputs and outputs: all the input and output tensors of the op
    are "connected".
  - passthrough: a passthrough tensor does not connect to any operation
    contained in the subgraph. One example is a subgraph representing a
    single tensor: this tensor is passthrough. By default a passthrough tensor
    is present both in the input and output tensors of the subgraph. It can
    however be remapped to only appear as an input (or output) only.

  The input and output tensors can be remapped. For instance, some input tensor
  can be omitted. For instance, a subgraph representing an operation with two
  inputs can be remapped to only take one input. Note that this does not change
  at all the underlying `tf.Graph` (remember, it is a view). It means that
  the other input is being ignored, or is being treated as "given".
  The analogy with functions can be extended like this: F(x,y) is the original
  function. Remapping the inputs from [x, y] to just [x] means that the subgraph
  now represents the function F_y(x) (y is "given").

  The output tensors can also be remapped. For instance, some output tensor can
  be omitted. Other output tensor can be duplicated as well. As mentioned
  before, this does not change at all the underlying `tf.Graph`.
  The analogy with functions can be extended like this: F(...)->x,y is the
  original function. Remapping the outputs from [x, y] to just [y,y] means that
  the subgraph now represents the function M(F(...)) where M is the function
  M(a,b)->b,b.

  It is useful to describe a few other kinds of tensors:

  * internal: an internal tensor is a tensor connecting operations contained
    in the subgraph. One example in the subgraph representing the two
    operations A and B connected sequentially: -> A -> B ->. The middle arrow
    is an internal tensor.
  * actual input: an input tensor of the subgraph, regardless of whether it is
    listed in "inputs" or not (masked-out).
  * actual output: an output tensor of the subgraph, regardless of whether it
    is listed in "outputs" or not (masked-out).
  * hidden input: an actual input which has been masked-out using an
    input remapping. In other words, a hidden input is a non-internal tensor
    not listed as an input tensor and one of whose consumers belongs to
    the subgraph.
  * hidden output: an actual output which has been masked-out using an output
    remapping. In other words, a hidden output is a non-internal tensor
    not listed as an output and one of whose generating operations belongs to
    the subgraph.

  Here are some useful guarantees about an instance of a SubGraphView:

  * the input (or output) tensors are not internal.
  * the input (or output) tensors are either "connected" or "passthrough".
  * the passthrough tensors are not connected to any of the operation of
    the subgraph.

  Note that there is no guarantee that an operation in a subgraph contributes
  at all to its inputs or outputs. For instance, remapping both the inputs and
  outputs to empty lists will produce a subgraph which still contains all the
  original operations. However, the remove_unused_ops function can be used to
  make a new subgraph view whose operations are connected to at least one of
  the input or output tensors.

  An instance of this class is meant to be a lightweight object which is not
  modified in-place by the user. Rather, the user can create new modified
  instances of a given subgraph. In that sense, the class SubGraphView is meant
  to be used like an immutable python object.

  A common problem when using views is that they can get out-of-sync with the
  data they observe (in this case, a `tf.Graph`). This is up to the user to
  ensure that this doesn't happen. To keep on the safe side, it is recommended
  that the life time of subgraph views are kept very short. One way to achieve
  this is to use subgraphs within a "with make_sgv(...) as sgv:" Python context.

  To alleviate the out-of-sync problem, some functions are granted the right to
  modify subgraphs in place. This is typically the case of graph manipulation
  functions which, given some subgraphs as arguments, can modify the underlying
  `tf.Graph`. Since this modification is likely to render the subgraph view
  invalid, those functions can modify the argument in place to reflect the
  change. For instance, calling the function swap_inputs(svg0, svg1) will
  modify svg0 and svg1 in place to reflect the fact that their inputs have now
  been swapped.
  """

  def __init__(self, inside_ops=(), passthrough_ts=()):
    """Create a subgraph containing the given ops and the "passthrough" tensors.

    Args:
      inside_ops: an object convertible to a list of `tf.Operation`. This list
        defines all the operations in the subgraph.
      passthrough_ts: an object convertible to a list of `tf.Tensor`. This list
        define all the "passthrough" tensors. A passthrough tensor is a tensor
        which goes directly from the input of the subgraph to its output,
        without any intermediate operations. All the non passthrough tensors
        are silently ignored.
    Raises:
      TypeError: if inside_ops cannot be converted to a list of `tf.Operation`
        or if `passthrough_ts` cannot be converted to a list of `tf.Tensor`.
    """

    inside_ops = util.make_list_of_op(inside_ops)
    passthrough_ts = util.make_list_of_t(passthrough_ts)
    ops_and_ts = inside_ops + passthrough_ts
    if ops_and_ts:
      self._graph = util.get_unique_graph(ops_and_ts)
      self._ops = inside_ops

      # Compute inside and outside tensor
      inputs, outputs, insides = select.compute_boundary_ts(inside_ops)

      # Compute passthrough tensors, silently ignoring the non-passthrough ones.
      all_tensors = frozenset(inputs + outputs + list(insides))
      self._passthrough_ts = [t for t in passthrough_ts
                              if t not in all_tensors]

      # Set inputs and outputs.
      self._input_ts = inputs + self._passthrough_ts
      self._output_ts = outputs + self._passthrough_ts
    else:
      self._graph = None
      self._passthrough_ts = []
      self._input_ts = []
      self._output_ts = []
      self._ops = []

  def __copy__(self):
    """Create a copy of this subgraph.

    Note that this class is a "view", copying it only create another view and
    does not copy the underlying part of the `tf.Graph`.

    Returns:
      A new identical instance of the original subgraph view.
    """
    cls = self.__class__
    result = cls.__new__(cls)
    for k, v in iteritems(self.__dict__):
      if k == "_graph":
        setattr(result, k, v)
      else:
        setattr(result, k, list(v))  # copy the list
    return result

  def _assign_from(self, other):
    """Assign other to itself.

    Args:
      other: another subgraph-view.
    Returns:
      A new instance identical to the original one.
    Raises:
      TypeError: if other is not an SubGraphView.
    """
    if not isinstance(other, SubGraphView):
      raise TypeError("Expected SubGraphView, got: {}".format(type(other)))
    # pylint: disable=protected-access
    self._graph = other._graph
    self._ops = list(other._ops)
    self._passthrough_ts = list(other._passthrough_ts)
    self._input_ts = list(other._input_ts)
    self._output_ts = list(other._output_ts)
    # pylint: enable=protected-access

  def copy(self):
    """Return a copy of itself.

    Note that this class is a "view", copying it only create another view and
    does not copy the underlying part of the tf.Graph.

    Returns:
      A new instance identical to the original one.
    """
    return copy.copy(self)

  def _remap_default(self, remove_input_map=True, remove_output_map=True):
    """Remap in-place the inputs and/or outputs to the default mapping.

    Args:
      remove_input_map: if True the input map is reset to the default one.
      remove_output_map: if True the output map is reset to the default one.
    """
    if not remove_input_map and not remove_output_map:
      return

    # Compute inside and outside tensor
    inputs, outputs, _ = select.compute_boundary_ts(self._ops)
    if remove_input_map:
      self._input_ts = list(inputs) + self._passthrough_ts
    if remove_output_map:
      self._output_ts = list(outputs) + self._passthrough_ts

  def remap_default(self, remove_input_map=True, remove_output_map=True):
    """Remap the inputs and/or outputs to the default mapping.

    Args:
      remove_input_map: if True the input map is reset to the default one.
      remove_output_map: if True the output map is reset to the default one.
    Returns:
      A new modified instance of the original subgraph view with its
        input and/or output mapping reset to the default one.
    """
    res = self.copy()
    res._remap_default(remove_input_map, remove_output_map)  # pylint: disable=protected-access
    return res

  def _remap_inputs(self, new_input_indices):
    """Remap the inputs of the subgraph in-place."""
    new_input_indices = _finalize_indices(new_input_indices, self._input_ts)
    _check_within_range(
        new_input_indices, len(self._input_ts), repetition=False)
    self._input_ts = [self._input_ts[i] for i in new_input_indices]

  def _remap_outputs(self, new_output_indices):
    """Remap the outputs of the subgraph in-place."""
    new_output_indices = _finalize_indices(new_output_indices, self._output_ts)
    _check_within_range(
        new_output_indices, len(self._output_ts), repetition=True)
    self._output_ts = [self._output_ts[i] for i in new_output_indices]

  def _remap_outputs_make_unique(self):
    """Remap the outputs in place so that all the tensors appear only once."""
    output_ts = list(self._output_ts)
    self._output_ts = []
    util.concatenate_unique(self._output_ts, output_ts)

  def _remap_outputs_to_consumers(self):
    """Remap the outputs in place to match the number of consumers."""
    self._remap_outputs_make_unique()
    output_ts = list(self._output_ts)
    self._output_ts = []
    for t in output_ts:
      self._output_ts += [t] * len(t.consumers())

  def remap_outputs_make_unique(self):
    """Remap the outputs so that all the tensors appear only once."""
    res = copy.copy(self)
    res._remap_outputs_make_unique()  # pylint: disable=protected-access
    return res

  def remap_outputs_to_consumers(self):
    """Remap the outputs to match the number of consumers."""
    res = copy.copy(self)
    res._remap_outputs_to_consumers()  # pylint: disable=protected-access
    return res

  def _remove_unused_ops(self, control_inputs=True):
    """Remove unused ops in place.

    Note: this mutates `self._ops` and returns None (unlike the public
    `remove_unused_ops`, which returns a new view).

    Args:
      control_inputs: if True, control inputs are used to detect used ops.
    """
    ops = select.get_walks_union_ops(
        self.connected_inputs,
        self.connected_outputs,
        within_ops=self._ops,
        control_inputs=control_inputs)
    self._ops = [op for op in self._ops if op in ops]

  def remove_unused_ops(self, control_inputs=True):
    """Remove unused ops.

    Args:
      control_inputs: if True, control inputs are used to detect used ops.
    Returns:
      A new subgraph view which only contains used operations.
    """
    res = copy.copy(self)
    res._remove_unused_ops(control_inputs)  # pylint: disable=protected-access
    return res

  def remap_inputs(self, new_input_indices):
    """Remap the inputs of the subgraph.

    If the inputs of the original subgraph are [t0, t1, t2], remapping to [2,0]
    will create a new instance whose inputs is [t2, t0].

    Note that this is only modifying the view: the underlying `tf.Graph` is not
    affected.

    Args:
      new_input_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old inputs and the new ones.
        Integers must be positive and smaller than the number of old inputs.
        tf.Tensors must belong to the old list of inputs.
        This mapping can be under-complete and must be without repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
        inputs.
    """
    res = self.copy()
    res._remap_inputs(new_input_indices)  # pylint: disable=protected-access
    return res

  def remap_outputs(self, new_output_indices):
    """Remap the output of the subgraph.

    If the output of the original subgraph are [t0, t1, t2], remapping to
    [1,1,0] will create a new instance whose outputs is [t1, t1, t0].

    Note that this is only modifying the view: the underlying tf.Graph is not
    affected.

    Args:
      new_output_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old outputs and the new ones.
        Integers must be positive and smaller than the number of old outputs.
        tf.Tensors must belong to the old list of outputs.
        This mapping can be under-complete and can have repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
        outputs.
    """
    res = copy.copy(self)
    res._remap_outputs(new_output_indices)  # pylint: disable=protected-access
    return res

  def remap(self, new_input_indices=None, new_output_indices=None):
    """Remap the inputs and outputs of the subgraph.

    Note that this is only modifying the view: the underlying tf.Graph is not
    affected.

    Args:
      new_input_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old inputs and the new ones.
        Integers must be positive and smaller than the number of old inputs.
        tf.Tensors must belong to the old list of inputs.
        This mapping can be under-complete and must be without repetitions.
      new_output_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old outputs and the new ones.
        Integers must be positive and smaller than the number of old outputs.
        tf.Tensors must belong to the old list of outputs.
        This mapping can be under-complete and can have repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
        inputs and outputs.
    """
    res = copy.copy(self)
    if new_input_indices is not None:
      res._remap_inputs(new_input_indices)  # pylint: disable=protected-access
    if new_output_indices is not None:
      res._remap_outputs(new_output_indices)  # pylint: disable=protected-access
    return res

  def find_op_by_name(self, op_name):
    """Return the op named op_name.

    Args:
      op_name: the name to search for
    Returns:
      The op named op_name.
    Raises:
      ValueError: if the op_name could not be found.
      AssertionError: if the name was found multiple time.
    """
    res = [op for op in self._ops if op.name == op_name]
    if not res:
      raise ValueError("{} not in subgraph.".format(op_name))
    if len(res) > 1:
      raise AssertionError("More than 1 op named: {}!".format(op_name))
    return res[0]

  def __str__(self):
    if not self:
      return "SubGraphView: empty"

    def op_name(op):
      return op.name

    def tensor_name(t):
      if t in self._passthrough_ts:
        return "{} *".format(t.name)
      else:
        return t.name

    def print_list(name, iterable, get_name):
      if iterable:
        print("** {}[{}]:".format(name, len(iterable)), file=res)
        print("\n".join(["  {}".format(get_name(elem)) for elem in iterable]),
              file=res)
      else:
        print("** {}: empty".format(name), file=res)

    res = StringIO()
    print("SubGraphView (graphid={}):".format(id(self.graph)), file=res)
    print_list("ops", self._ops, op_name)
    print_list("inputs", self._input_ts, tensor_name)
    print_list("outputs", self._output_ts, tensor_name)
    return res.getvalue()

  @property
  def graph(self):
    """The underlying `tf.Graph`."""
    return self._graph

  @property
  def ops(self):
    """The operations in this subgraph view."""
    return self._ops

  @property
  def inputs(self):
    """The input tensors of this subgraph view."""
    return util.ListView(self._input_ts)

  @property
  def connected_inputs(self):
    """The connected input tensors of this subgraph view."""
    return [t for t in self._input_ts if t not in self._passthrough_ts]

  @property
  def outputs(self):
    """The output tensors of this subgraph view."""
    return util.ListView(self._output_ts)

  @property
  def connected_outputs(self):
    """The connected output tensors of this subgraph view."""
    return [t for t in self._output_ts if t not in self._passthrough_ts]

  @property
  def passthroughs(self):
    """The passthrough tensors, going straight from input to output."""
    return util.ListView(self._passthrough_ts)

  def __bool__(self):
    """Allows for implicit boolean conversion."""
    return self._graph is not None

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__

  def op(self, op_id):
    """Get an op by its index."""
    return self._ops[op_id]

  def is_passthrough(self, t):
    """Check whether a tensor is passthrough."""
    return t in self._passthrough_ts

  def __enter__(self):
    """Allow Python context to minimize the life time of a subgraph view.

    A subgraph view is meant to be a lightweight and transient object. A short
    lifetime will alleviate the "out-of-sync" issue mentioned earlier. For that
    reason, a SubGraphView instance can be used within a Python context. For
    example:

    from tensorflow.contrib import graph_editor as ge
    with ge.make_sgv(...) as sgv:
      print(sgv)

    Returns:
      Itself.
    """
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    pass

  def input_index(self, t):
    """Find the input index corresponding to the given input tensor t.

    Args:
      t: the input tensor of this subgraph view.
    Returns:
      The index in the self.inputs list.
    Raises:
      ValueError: if t is not an input tensor.
    """
    try:
      subgraph_id = self._input_ts.index(t)
    except ValueError:
      # Catch only ValueError (tensor not in list): a bare except would also
      # swallow KeyboardInterrupt/SystemExit. Note: the previous message
      # referenced the nonexistent attribute `self.name`, which raised
      # AttributeError instead of the intended ValueError.
      raise ValueError("Can't find {} in inputs of subgraph.".format(t.name))
    return subgraph_id

  def output_index(self, t):
    """Find the output index corresponding to given output tensor t.

    Args:
      t: the output tensor of this subgraph view.
    Returns:
      The index in the self.outputs list.
    Raises:
      ValueError: if t is not an output tensor.
    """
    try:
      subgraph_id = self._output_ts.index(t)
    except ValueError:
      # Catch only ValueError (tensor not in list): a bare except would also
      # swallow KeyboardInterrupt/SystemExit. Note: the previous message
      # referenced the nonexistent attribute `self.name`, which raised
      # AttributeError instead of the intended ValueError.
      raise ValueError("Can't find {} in outputs of subgraph.".format(t.name))
    return subgraph_id

  def consumers(self):
    """Return a Python set of all the consumers of this subgraph view.

    A consumer of a subgraph view is a tf.Operation which is a consumer
    of one of the output tensors and is not in the subgraph.

    Returns:
      A list of `tf.Operation` which are the consumers of this subgraph view.
    """
    ops_set = frozenset(self._ops)
    res = []
    for output in self._output_ts:
      consumers = [op for op in output.consumers() if op not in ops_set]
      util.concatenate_unique(res, consumers)
    return res


def _check_graph(sgv, graph):
  """Check if sgv belongs to the given graph.

  Args:
    sgv: a SubGraphView.
    graph: a graph or None.
  Returns:
    The SubGraphView sgv.
  Raises:
    TypeError: if sgv is not a SubGraphView or if graph is not None and not
      a tf.Graph.
    ValueError: if the graph of sgv and the given graph are not None and
      different.
  """
  if not isinstance(sgv, SubGraphView):
    # Report the type of the offending argument (sgv), not of `graph`.
    raise TypeError("Expected a SubGraphView, got: {}".format(type(sgv)))
  if graph is None or not sgv.graph:
    return sgv
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
  if sgv.graph is not graph:
    raise ValueError("Graph mismatch.")
  return sgv


def make_view(*args, **kwargs):
  """Create a SubGraphView from selected operations and passthrough tensors.

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      `tf.Operation` 3) (array of) `tf.Tensor`. Those objects will be converted
      into a list of operations and a list of candidate for passthrough tensors.
    **kwargs: keyword graph is used 1) to check that the ops and ts are from
      the correct graph 2) for regular expression query
  Returns:
    A subgraph view.
  Raises:
    TypeError: if the optional keyword argument graph is not a `tf.Graph`
      or if an argument in args is not an (array of) `tf.Tensor`
      or an (array of) `tf.Operation` or a string or a regular expression.
    ValueError: if one of the keyword arguments is unexpected.
  """
  # Get keyword arguments. Note: "graph" is deliberately left in kwargs so
  # that select_ops_and_ts can also use it for regular-expression queries.
  graph = kwargs.get("graph", None)

  # already a view?
  if len(args) == 1 and isinstance(args[0], SubGraphView):
    return _check_graph(args[0], graph)

  ops, ts = select.select_ops_and_ts(*args, **kwargs)
  sgv = SubGraphView(ops, ts)
  return _check_graph(sgv, graph)


def make_view_from_scope(scope, graph):
  """Make a subgraph from a name scope.

  Args:
    scope: the name of the scope.
    graph: the `tf.Graph`.
  Returns:
    A subgraph view representing the given scope.
  """
  ops = select.get_name_scope_ops(graph, scope)
  return SubGraphView(ops)