1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Class to transform an subgraph into another. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from copy import deepcopy 23from functools import partial 24from six import iteritems 25from six import string_types 26from six import StringIO 27from tensorflow.contrib.graph_editor import reroute 28from tensorflow.contrib.graph_editor import select 29from tensorflow.contrib.graph_editor import subgraph 30from tensorflow.contrib.graph_editor import util 31from tensorflow.python.framework import ops as tf_ops 32from tensorflow.python.platform import tf_logging as logging 33 34 35__all__ = [ 36 "replace_t_with_placeholder_handler", 37 "keep_t_if_possible_handler", 38 "assign_renamed_collections_handler", 39 "transform_op_if_inside_handler", 40 "copy_op_handler", 41 "Transformer", 42 "TransformerInfo", 43 "copy", 44 "copy_with_input_replacements", 45 "graph_replace", 46] 47 48 49def replace_t_with_placeholder_handler(info, t): 50 """Transform a tensor into a placeholder tensor. 51 52 This handler is typically used to transform a subgraph input tensor into a 53 placeholder. 54 55 Args: 56 info: Transform._TmpInfo instance. 57 t: tensor whose input must be transformed into a place holder. 58 Returns: 59 The tensor generated by the newly created place holder. 60 """ 61 with info.graph_.as_default(): 62 t_ = util.make_placeholder_from_tensor(t, scope=info.scope_) 63 return t_ 64 65 66def keep_t_if_possible_handler(info, t): 67 """Transform a tensor into itself (identity) if possible. 68 69 This handler transform a tensor into itself if the source and destination 70 graph are the same. Otherwise it will create a placeholder. 71 This handler is typically used to transform a hidden input tensors. 72 73 Args: 74 info: Transform._TmpInfo instance. 75 t: tensor whose input must be transformed into a place holder. 76 Returns: 77 The tensor generated by the newly created place holder. 78 """ 79 if info.graph is info.graph_: 80 return t 81 else: 82 return replace_t_with_placeholder_handler(info, t) 83 84 85def assign_renamed_collections_handler(info, elem, elem_): 86 """Add the transformed elem to the (renamed) collections of elem. 87 88 A collection is renamed only if is not a known key, as described in 89 `tf.GraphKeys`. 90 91 Args: 92 info: Transform._TmpInfo instance. 93 elem: the original element (`tf.Tensor` or `tf.Operation`) 94 elem_: the transformed element 95 """ 96 known_collection_names = util.get_predefined_collection_names() 97 for name, collection in iteritems(info.collections): 98 if elem not in collection: 99 continue 100 101 if name in known_collection_names: 102 transformed_name = name 103 else: 104 transformed_name = info.new_name(name) 105 info.graph_.add_to_collection(transformed_name, elem_) 106 107 108def transform_op_if_inside_handler(info, op, keep_if_possible=True): 109 """Transform an optional op only if it is inside the subgraph. 110 111 This handler is typically use to handle original op: it is fine to keep them 112 if they are inside the subgraph, otherwise they are just ignored. 113 114 Args: 115 info: Transform._TmpInfo instance. 116 op: the optional op to transform (or ignore). 117 keep_if_possible: re-attach to the original op if possible, that is, 118 if the source graph and the destination graph are the same. 119 Returns: 120 The transformed op or None. 121 """ 122 if op in info.sgv.ops: 123 return info.transformed_ops[op] 124 else: 125 if keep_if_possible and info.graph is info.graph_: 126 return op 127 else: 128 return None 129 130 131def copy_op_handler(info, op, new_inputs, copy_shape=False, nodedef_fn=None): 132 """Copy a `tf.Operation`. 133 134 Args: 135 info: Transform._TmpInfo instance. 136 op: the `tf.Operation` to be copied. 137 new_inputs: The new inputs for this op. 138 copy_shape: also copy the shape of the tensor 139 nodedef_fn: If provided, a function that will be run on the NodeDef 140 and should return a mutated NodeDef before a new Operation is created. 141 This is useful as certain features cannot be set on the Operation and 142 must be modified in NodeDef. 143 144 Returns: 145 A `(op, op_outputs)` tuple containing the transformed op and its outputs. 146 """ 147 # The `new_inputs` was added to this function. For compatibility reason, 148 # let's raise an error if `new_inputs` is a boolean. 149 if isinstance(new_inputs, bool): 150 raise TypeError("the `new_inputs` argument must be an iterable.") 151 152 # pylint: disable=protected-access 153 154 # Clone the node def: 155 node_def_ = deepcopy(op.node_def) 156 157 # Transform name: 158 name_ = info.new_name(op.name) 159 name_ = info.graph_.unique_name(name_) 160 node_def_.name = name_ 161 162 # Mutate NodeDef if requested: 163 if nodedef_fn is not None: 164 node_def_ = nodedef_fn(node_def_) 165 166 # Copy the other inputs needed for initialization 167 output_types_ = op._output_types[:] 168 input_types_ = op._input_types[:] 169 170 # Make a copy of the op_def too. 171 # Its unique to every _type_ of Operation. 172 op_def_ = deepcopy(op.op_def) 173 174 # Initialize a new Operation instance 175 op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, 176 [], input_types_, None, op_def_) 177 178 # copy the shape over 179 if copy_shape: 180 for t, t_ in zip(op.outputs, op_.outputs): 181 t_.set_shape(t.get_shape()) 182 183 # Original op cannot be finalised here yet. Because some ops require this 184 # attribute to exist, we will create a dummy original_op first and then 185 # later finalise it with the actual original_op when all the ops have 186 # been copied. 187 # TODO(fkp): Stop worrying about _original_op and remove this code? 188 if op._original_op: 189 op_._original_op = op._original_op 190 191 return op_, op_.outputs 192 193 194class TransformerInfo(object): 195 """"Contains information about the result of a transform operation.""" 196 197 def __init__(self, info): 198 """Constructor. 199 200 Args: 201 info: an instance of Transformer._TmpInfo containing various internal 202 information about the transform operation. 203 """ 204 self._graph = info.graph 205 self._scope = info.scope 206 self._graph_ = info.graph_ 207 self._scope_ = info.scope_ 208 self._transformed_ops = info.transformed_ops 209 self._transformed_ts = info.transformed_ts 210 211 def _get_transformed_map(self, top): 212 """Return the correct container depending on the type of `top`.""" 213 if isinstance(top, tf_ops.Operation): 214 return self._transformed_ops 215 elif isinstance(top, tf_ops.Tensor): 216 return self._transformed_ts 217 else: 218 raise TypeError( 219 "Expected a tf.Tensor or a tf.Operation, got a {}".format( 220 type(top))) 221 222 def _transformed_elem(self, original_top, missing_fn=None): 223 """Return the transformed op/tensor corresponding to the original one. 224 225 Args: 226 original_top: the original tensor/operation. 227 missing_fn: function handling the case where the counterpart 228 cannot be found. By default, None is returned. 229 Returns: 230 the transformed tensor/operation (or None if no match is found). 231 """ 232 transformed_map = self._get_transformed_map(original_top) 233 if isinstance(original_top, string_types): 234 for original, transformed in iteritems(transformed_map): 235 if original.name == original_top: 236 return transformed 237 return None if missing_fn is None else missing_fn(original_top) 238 else: 239 if original_top not in transformed_map: 240 return None if missing_fn is None else missing_fn(original_top) 241 return transformed_map[original_top] 242 243 def _original_elem(self, transformed_top, missing_fn=None): 244 """Return the original op/tensor corresponding to the transformed one. 245 246 Args: 247 transformed_top: the transformed tensor/operation. 248 missing_fn: function handling the case where the counterpart 249 cannot be found. By default, None is returned. 250 Returns: 251 the original tensor/operation (or None if no match is found). 252 """ 253 transformed_map = self._get_transformed_map(transformed_top) 254 if isinstance(transformed_top, string_types): 255 finder = lambda transformed: transformed.name == transformed_top 256 else: 257 finder = lambda transformed: transformed == transformed_top 258 for original, transformed in iteritems(transformed_map): 259 if finder(transformed): 260 return original 261 return None if missing_fn is None else missing_fn(transformed_top) 262 263 def transformed(self, original, missing_fn=None): 264 """Return the transformed op/tensor corresponding to the original one. 265 266 Note that the output of this function mimics the hierarchy 267 of its input argument `original`. 268 Given an iterable, it returns a list. Given an operation or a tensor, 269 it will return an operation or a tensor. 270 271 Args: 272 original: the original tensor/operation. 273 missing_fn: function handling the case where the counterpart 274 cannot be found. By default, None is returned. 275 Returns: 276 the transformed tensor/operation (or None if no match is found). 277 """ 278 transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn) 279 return util.transform_tree(original, transformed_elem) 280 281 def original(self, transformed, missing_fn=None): 282 """Return the original op/tensor corresponding to the transformed one. 283 284 Note that the output of this function mimics the hierarchy 285 of its input argument `transformed`. 286 Given an iterable, it returns a list. Given an operation or a tensor, 287 it will return an operation or a tensor. 288 289 Args: 290 transformed: the transformed tensor/operation. 291 missing_fn: function handling the case where the counterpart 292 cannot be found. By default, None is returned. 293 Returns: 294 the original tensor/operation (or None if no match is found). 295 """ 296 original_elem = partial(self._original_elem, missing_fn=missing_fn) 297 return util.transform_tree(transformed, original_elem) 298 299 def __str__(self): 300 res = StringIO() 301 print("Transform result info:", file=res) 302 if self._graph == self._graph_: 303 in_place_str = "" if self._scope_ else " IN-PLACE" 304 print(" Within graph[{}]{}".format( 305 id(self._graph), in_place_str), file=res) 306 else: 307 print(" graph[{}] => graph[{}]".format( 308 id(self._graph), id(self._graph_)), file=res) 309 if self._scope: 310 print(" Relative to source scope: {}".format(self._scope), file=res) 311 if self._scope_: 312 print(" Scope destination: {}".format(self._scope_), file=res) 313 print("Operations mapping:", file=res) 314 for op, op_ in iteritems(self._transformed_ops): 315 print(" {} => {}".format(op.name, op_.name), file=res) 316 return res.getvalue() 317 318 319class _TmpInfo(object): 320 """Transformer temporary data. 321 322 An instance of this class holds all the information relevant to a call 323 to a transformer instance (that is, a call to __call__). An instance 324 is created for the life-time of the __call__ function and is passed as 325 argument to the handlers. 326 """ 327 328 def __init__(self, sgv, dst_graph, dst_scope, src_scope): 329 self.sgv = sgv 330 self.sgv_inputs_set = frozenset(sgv.inputs) 331 self.ops = frozenset(sgv.ops) 332 self.control_outputs = util.ControlOutputs(sgv.graph) 333 self.graph = sgv.graph 334 self.scope = src_scope 335 self.graph_ = dst_graph 336 self.scope_ = dst_scope 337 self.transformed_ops = {} 338 self.transformed_ts = {} 339 self.collections = dict((key, self.graph.get_collection(key)) 340 for key in self.graph.get_all_collection_keys()) 341 self.cyclic_ops = [] 342 self.transform_original_op_handler = transform_op_if_inside_handler 343 # The graph is transformed op by op, in the same order the original ops 344 # were created. However, this is sometimes not possible due to cycles 345 # (i.e. while loops). So when the transformer creates a new op whose 346 # inputs do not exist yet, temporary placeholders are created and stored 347 # in this `tmp_cyclic_ts` container. During a second pass, 348 # those temporary tensors are replaced by the proper transformed tensors 349 # (see the function `_finalize_cycles`). 350 self.tmp_cyclic_ts = [] 351 352 def new_name(self, name): 353 """Compute a destination name from a source name. 354 355 Args: 356 name: the name to be "transformed". 357 Returns: 358 The transformed name. 359 Raises: 360 ValueError: if the source scope is used (that is, not an empty string) 361 and the source name does not belong to the source scope. 362 """ 363 scope = self.scope 364 if not name.startswith(scope): 365 raise ValueError("{} does not belong to source scope: {}.".format( 366 name, scope)) 367 rel_name = name[len(scope):] 368 name_ = self.scope_ + rel_name 369 return name_ 370 371 372class Transformer(object): 373 """Transform a subgraph into another one. 374 375 By default, the constructor create a transform which copy a subgraph and 376 replaces inputs with placeholders. This behavior can be modified by changing 377 the handlers. 378 """ 379 380 def __init__(self): 381 """Transformer constructor. 382 383 The following members can be modified: 384 transform_op_handler: handle the transformation of a `tf.Operation`. 385 This handler defaults to a simple copy. 386 assign_collections_handler: handle the assignment of collections. 387 This handler defaults to assigning new collections created under the 388 given name-scope. 389 transform_external_input_handler: handle the transform of the inputs to 390 the given subgraph. This handler defaults to creating placeholders 391 instead of the ops just before the input tensors of the subgraph. 392 transform_external_hidden_input_handler: handle the transform of the 393 hidden inputs of the subgraph, that is, the inputs which are not listed 394 in sgv.inputs. This handler defaults to a transform which keep the same 395 input if the source and destination graphs are the same, otherwise 396 use placeholders. 397 transform_original_op_handler: handle the transform of original_op. This 398 handler defaults to transforming original_op only if they are in the 399 subgraph, otherwise they are ignored. 400 """ 401 402 # handlers 403 self.transform_op_handler = copy_op_handler 404 self.transform_control_input_handler = transform_op_if_inside_handler 405 self.assign_collections_handler = assign_renamed_collections_handler 406 self.transform_external_input_handler = replace_t_with_placeholder_handler 407 self.transform_external_hidden_input_handler = keep_t_if_possible_handler 408 self.transform_original_op_handler = transform_op_if_inside_handler 409 410 def __call__(self, 411 sgv, 412 dst_graph, 413 dst_scope, 414 src_scope="", 415 reuse_dst_scope=False): 416 """Execute the transformation. 417 418 Args: 419 sgv: the source subgraph-view. 420 dst_graph: the destination graph. 421 dst_scope: the destination scope. 422 src_scope: the source scope, which specify the path from which the 423 relative path of the transformed nodes are computed. For instance, if 424 src_scope is a/ and dst_scoped is b/, then the node a/x/y will have a 425 relative path of x/y and will be transformed into b/x/y. 426 reuse_dst_scope: if True the dst_scope is re-used if it already exists. 427 Otherwise, the scope is given a unique name based on the one given 428 by appending an underscore followed by a digit (default). 429 Returns: 430 A tuple `(sgv, info)` where: 431 `sgv` is the transformed subgraph view; 432 `info` is an instance of TransformerInfo containing 433 information about the transform, including mapping between 434 original and transformed tensors and operations. 435 Raises: 436 ValueError: if the arguments are invalid. 437 """ 438 sgv = subgraph.make_view(sgv) 439 if not isinstance(dst_graph, tf_ops.Graph): 440 raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) 441 442 src_scope = util.scope_finalize(src_scope) 443 dst_scope = util.scope_finalize(dst_scope) 444 445 # Potentially create new scope if reuse_dst_scope is False 446 if dst_scope and not reuse_dst_scope: 447 dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1])) 448 449 # Create temporary info used during this transform call 450 info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) 451 452 self._copy_ops(info) 453 self._finalize_cycles(info) 454 self._connect_control_inputs(info) 455 456 # Compute information about the transformation 457 res_info = TransformerInfo(info) 458 sgv_ = self._transform_sgv(info, sgv) 459 return sgv_, res_info 460 461 def _copy_ops(self, info): 462 """Copy ops without connecting them.""" 463 sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access 464 for op in sorted_ops: 465 new_inputs = [self._transformed_t(info, t, op) for t in op.inputs] 466 op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) 467 if op is op_: 468 raise ValueError("In-place transformation not allowed.") 469 470 # Process op. 471 info.transformed_ops[op] = op_ 472 self.assign_collections_handler(info, op, op_) 473 474 # Process output tensors. 475 for op_output, op_output_ in zip(op.outputs, op_outputs_): 476 info.transformed_ts[op_output] = op_output_ 477 self.assign_collections_handler(info, op_output, op_output_) 478 479 def _finalize_cycles(self, info): 480 """Reconnects the cyclic tensors.""" 481 for t, tmp_t_, consumer_op in info.tmp_cyclic_ts: 482 if t not in info.transformed_ts: 483 raise ValueError("The tensor {} should be transformed by now.".format( 484 t.name)) 485 if consumer_op not in info.transformed_ops: 486 raise ValueError("The op {} should be transformed by now.".format( 487 consumer_op.name)) 488 t_ = info.transformed_ts[t] 489 consumer_op_ = info.transformed_ops[consumer_op] 490 t_index_ = list(consumer_op_.inputs).index(tmp_t_) 491 consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access 492 493 def _connect_control_inputs(self, info): 494 """Connect the previously copied ops.""" 495 for op in info.sgv.ops: 496 logging.debug("Connecting control inputs of op: %s", op.name) 497 op_ = info.transformed_ops[op] 498 499 # Finalize original op. 500 # TODO(fkp): Stop worrying about _original_op and remove this code? 501 # pylint: disable=protected-access 502 if op._original_op: 503 original_op = self.transform_original_op_handler(info, op._original_op) 504 if original_op is None: 505 logging.debug("Could not find original op for: %s", op_.name) 506 else: 507 op_._original_op = original_op 508 # pylint: enable=protected-access 509 510 # Finalize control inputs: 511 control_inputs_ = [self.transform_control_input_handler(info, ci) 512 for ci in op.control_inputs] 513 control_inputs_ = [ci for ci in control_inputs_ if ci is not None] 514 reroute.add_control_inputs(op_, control_inputs_) 515 516 def _transform_sgv(self, info, sgv): 517 """Transform a subgraph view. 518 519 For convenience, a transform operation returns a subgraph view of the 520 transformed graph. 521 522 Args: 523 info: Temporary information for this transorfm call. 524 sgv: the subgraph to be transformed. 525 Returns: 526 The transformed subgraph. 527 """ 528 ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)] 529 sgv_ = subgraph.SubGraphView(ops_) 530 sgv_inputs_ = sgv_.inputs 531 sgv_outputs_ = sgv_.outputs 532 533 # re-order inputs 534 input_map_ = [] 535 for input_t in sgv.inputs: 536 if input_t not in info.transformed_ts: 537 continue 538 input_t_ = info.transformed_ts[input_t] 539 if input_t_ not in sgv_inputs_: 540 continue 541 input_t_index_ = sgv_.input_index(input_t_) 542 input_map_.append(input_t_index_) 543 544 # re-order outputs 545 output_map_ = [] 546 for output_t in sgv.outputs: 547 if output_t not in info.transformed_ts: 548 continue 549 output_t_ = info.transformed_ts[output_t] 550 if output_t_ not in sgv_outputs_: 551 continue 552 output_t_index_ = sgv_.output_index(output_t_) 553 output_map_.append(output_t_index_) 554 555 return sgv_.remap(input_map_, output_map_) 556 557 def _transformed_t(self, info, t, consumer_op): 558 """Return tre transformed tensor of `t`.""" 559 if t in info.transformed_ts: 560 # If op is in the subgraph, just return its transformed counterpart. 561 return info.transformed_ts[t] 562 563 if t in info.sgv_inputs_set: 564 # `t` is an input of the subgraph. 565 return self.transform_external_input_handler(info, t) 566 elif t.op in info.ops: 567 # `t` is an internal tensor but is not transformed yet because it 568 # belongs to a graph cycle. 569 logging.debug("Cyclic tensor: t.name = %s", t.name) 570 # Try to find an existing tensor we can use for now, 571 # otherwise create one. We'll rewire this later. 572 if consumer_op.type == "Merge": 573 first_input = consumer_op.inputs[0] 574 tmp_t_ = self._transformed_t(info, first_input, consumer_op) 575 elif t.op.type == "Enter": 576 enter_input = t.op.inputs[0] 577 tmp_t_ = self._transformed_t(info, enter_input, consumer_op) 578 else: 579 with info.graph_.as_default(): 580 tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_, 581 prefix="geph_tmp") 582 logging.debug("Created temporary placeholder: %s.", tmp_t_.name) 583 # Register as temporary and return. 584 info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op)) 585 return tmp_t_ 586 else: 587 # `t` is a hidden input of the subgraph. 588 return self.transform_external_hidden_input_handler(info, t) 589 590 591def copy(sgv, dst_graph=None, dst_scope="", src_scope="", 592 reuse_dst_scope=False): 593 """Copy a subgraph. 594 595 Args: 596 sgv: the source subgraph-view. This argument is converted to a subgraph 597 using the same rules than the function subgraph.make_view. 598 dst_graph: the destination graph. 599 dst_scope: the destination scope. 600 src_scope: the source scope. 601 reuse_dst_scope: if True the dst_scope is re-used if it already exists. 602 Otherwise, the scope is given a unique name based on the one given 603 by appending an underscore followed by a digit (default). 604 Returns: 605 A tuple `(sgv, info)` where: 606 `sgv` is the transformed subgraph view; 607 `info` is an instance of TransformerInfo containing 608 information about the transform, including mapping between 609 original and transformed tensors and operations. 610 Raises: 611 TypeError: if `dst_graph` is not a `tf.Graph`. 612 StandardError: if sgv cannot be converted to a SubGraphView using 613 the same rules than the function subgraph.make_view. 614 """ 615 sgv = subgraph.make_view(sgv) 616 if dst_graph is None: 617 dst_graph = sgv.graph 618 if not isinstance(dst_graph, tf_ops.Graph): 619 raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) 620 621 copier = Transformer() 622 return copier( 623 sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) 624 625 626def copy_with_input_replacements(sgv, replacement_ts, 627 dst_graph=None, dst_scope="", src_scope="", 628 reuse_dst_scope=False): 629 """Copy a subgraph, replacing some of its inputs. 630 631 Note a replacement only happens if the tensor to be replaced 632 is an input of the given subgraph. The inputs of a subgraph can 633 be queried using sgv.inputs. 634 635 Args: 636 sgv: the source subgraph-view. This argument is converted to a subgraph 637 using the same rules as the function subgraph.make_view. 638 replacement_ts: dictionary mapping from original tensors to the 639 replaced one. 640 dst_graph: the destination graph. 641 dst_scope: the destination scope. 642 src_scope: the source scope. 643 reuse_dst_scope: if True the dst_scope is re-used if it already exists. 644 Otherwise, the scope is given a unique name based on the one given 645 by appending an underscore followed by a digit (default). 646 Returns: 647 A tuple `(sgv, info)` where: 648 `sgv` is the transformed subgraph view; 649 `info` is an instance of TransformerInfo containing 650 information about the transform, including mapping between 651 original and transformed tensors and operations. 652 Raises: 653 TypeError: if dst_graph is not a tf.Graph. 654 StandardError: if sgv cannot be converted to a SubGraphView using 655 the same rules as the function subgraph.make_view. 656 """ 657 sgv = subgraph.make_view(sgv) 658 if dst_graph is None: 659 dst_graph = sgv.graph 660 if not isinstance(dst_graph, tf_ops.Graph): 661 raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) 662 663 copier = Transformer() 664 # Replace tensor if possible. 665 def replace_t_with_replacement_handler(info, t): 666 if t in replacement_ts: 667 return replacement_ts[t] 668 else: 669 return keep_t_if_possible_handler(info, t) 670 copier.transform_external_input_handler = replace_t_with_replacement_handler 671 return copier( 672 sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) 673 674 675def _add_control_flow_ops(ops, control_ios): 676 """Complete `ops` so that the transformed graph is valid. 677 678 Partially copying a graph can lead to a malformed graph. For instance, 679 copying half of a while construct is likely to result in an invalid graph. 680 This function attempts to add missing ops so that the transformation result 681 in a valid graph. 682 683 Args: 684 ops: list of ops (modifed in-place). 685 control_ios: object created by a call to `util.ControlOutputs`. 686 """ 687 # Find while contexts. 688 control_flow_contexts = set() 689 for op in ops: 690 cfc = op._control_flow_context # pylint: disable=protected-access 691 if cfc: 692 control_flow_contexts.add(cfc) 693 # Find new ops. 694 new_ops = [] 695 for cfc in control_flow_contexts: 696 if cfc.IsWhileContext(): 697 new_ops += select.get_walks_intersection_ops( 698 [enter_t.op for enter_t in cfc.loop_enters], 699 [exit_t.op for exit_t in cfc.loop_exits], 700 control_ios=control_ios) 701 # Add new ops. 702 new_ops_set = set(new_ops) 703 ops_set = frozenset(ops) 704 for op in new_ops_set: 705 if op not in ops_set: 706 ops.append(op) 707 708 709def graph_replace(target_ts, replacement_ts, dst_scope="", 710 src_scope="", reuse_dst_scope=False): 711 """Create a new graph which compute the targets from the replaced Tensors. 712 713 Args: 714 target_ts: a single tf.Tensor or an iterable of tf.Tensor. 715 replacement_ts: dictionary mapping from original tensors to replaced tensors 716 dst_scope: the destination scope. 717 src_scope: the source scope. 718 reuse_dst_scope: if True the dst_scope is re-used if it already exists. 719 Otherwise, the scope is given a unique name based on the one given 720 by appending an underscore followed by a digit (default). 721 Returns: 722 A single tf.Tensor or a list of target tf.Tensor, depending on 723 the type of the input argument `target_ts`. 724 The returned tensors are recomputed using the tensors from replacement_ts. 725 Raises: 726 ValueError: if the targets are not connected to replacement_ts. 727 """ 728 # Identify operations in the graph that will change. 729 # Start forward walk at Tensors that will be replaced, and 730 # backward walk at the target output Tensors. 731 flatten_target_ts = util.flatten_tree(target_ts) 732 # Construct the forward control dependencies edges so that 733 # the get_walks_intersection_ops can also traverse the 734 # control dependencies. 735 graph = util.get_unique_graph(flatten_target_ts, check_types=(tf_ops.Tensor)) 736 control_ios = util.ControlOutputs(graph) 737 ops = select.get_walks_intersection_ops( 738 list(replacement_ts), flatten_target_ts, control_ios=control_ios) 739 if not ops: 740 raise ValueError("Targets and replacements are not connected!") 741 742 # Complete ops to avoid malformed control flow. 743 # TODO(fkp): Consider moving this function deeper (in the transformer?). 744 _add_control_flow_ops(ops, control_ios) 745 746 # Create a copy of the relevant subgraph 747 unused_sgv_, info = copy_with_input_replacements( 748 ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope) 749 # Return the transformed targets but keep the original if the transformed 750 # counterpart cannot be found 751 missing_fn = lambda original_t: original_t 752 return info.transformed(target_ts, missing_fn) 753