# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various functions for graph editing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.ops import array_ops as tf_array_ops

__all__ = [
    "detach_control_inputs",
    "detach_control_outputs",
    "detach_inputs",
    "detach_outputs",
    "detach",
    "connect",
    "bypass",
]


def detach_control_inputs(sgv):
  """Detach all the external control inputs of the subgraph sgv.

  A control input is considered external when the controlling op is not
  itself one of the ops of sgv.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  for op in sgv.ops:
    # Keep control dependencies that stay inside the subgraph.
    cops = [cop for cop in op.control_inputs if cop not in sgv.ops]
    reroute.remove_control_inputs(op, cops)


def detach_control_outputs(sgv, control_outputs):
  """Detach all the external control outputs of the subgraph sgv.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
    control_outputs: a util.ControlOutputs instance.
  Raises:
    TypeError: if control_outputs is not a util.ControlOutputs instance.
  """
  if not isinstance(control_outputs, util.ControlOutputs):
    # The type must be interpolated into the message; passing it as a second
    # positional argument leaves a literal "{}" in the error text.
    raise TypeError("Expected a util.ControlOutputs, got: {}".format(
        type(control_outputs)))
  control_outputs.update()
  sgv = subgraph.make_view(sgv)
  for op in sgv.ops:
    for cop in control_outputs.get(op):
      # Only cut control edges that leave the subgraph.
      if cop not in sgv.ops:
        reroute.remove_control_inputs(cop, op)


def detach_inputs(sgv, control_inputs=False):
  """Detach the inputs of a subgraph view.

  Every input tensor of sgv is replaced by a freshly created placeholder of
  the same dtype, in the graph of sgv.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_inputs: if True control_inputs are also detached.
  Returns:
    A tuple `(sgv, input_placeholders)` where
      `sgv` is a new subgraph view of the detached subgraph;
      `input_placeholders` is a list of the created input placeholders.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)

  # Placeholders must live in the same graph as the subgraph being detached.
  with sgv.graph.as_default():
    input_placeholders = [
        tf_array_ops.placeholder(
            dtype=input_t.dtype, name=util.placeholder_name(input_t))
        for input_t in sgv.inputs
    ]

  reroute.swap_inputs(sgv, input_placeholders)
  if control_inputs:
    detach_control_inputs(sgv)
  return sgv, input_placeholders


def detach_outputs(sgv, control_outputs=None):
  """Detach the outputs of a subgraph view.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_outputs: a util.ControlOutputs instance or None. If not None the
      control outputs are also detached.
  Returns:
    A tuple `(sgv, output_placeholders)` where
      `sgv` is a new subgraph view of the detached subgraph, restricted to
        the outputs that had consumers;
      `output_placeholders` is a list of the created output placeholders.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  # Only select outputs with consumers: outputs nobody reads need no
  # replacement placeholder.
  sgv_ = sgv.remap_outputs([output_id
                            for output_id, output_t in enumerate(sgv.outputs)
                            if output_t.consumers()])
  # Create the consumer subgraph and keep only the inputs it receives from
  # sgv_, so that its inputs line up with the outputs being detached.
  # NOTE(review): swap_outputs below relies on consumers_sgv.inputs pairing
  # up with sgv_.outputs; confirm ordering guarantees in subgraph/reroute.
  consumers_sgv = subgraph.SubGraphView(sgv_.consumers())
  consumers_sgv = consumers_sgv.remap_inputs(
      [input_id for input_id, input_t in enumerate(consumers_sgv.inputs)
       if input_t in sgv_.outputs])

  with sgv_.graph.as_default():
    output_placeholders = [
        util.make_placeholder_from_tensor(input_t)
        for input_t in consumers_sgv.inputs
    ]

  reroute.swap_outputs(sgv_, output_placeholders)
  if control_outputs is not None:
    detach_control_outputs(sgv_, control_outputs)
  return sgv_, output_placeholders


def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None):
  """Detach both the inputs and the outputs of a subgraph view.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios: An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to
      setting control_inputs to True and control_outputs to the
      util.ControlOutputs instance.
  Returns:
    A tuple `(sgv, detached_inputs, detached_outputs)` where:
      `sgv` is a new subgraph view of the detached subgraph;
      `detached_inputs` is a list of the created input placeholders;
      `detached_outputs` is a list of the created output placeholders.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  control_inputs, control_outputs = select.check_cios(control_inputs,
                                                      control_outputs,
                                                      control_ios)
  # Convert once up front (as connect/bypass do) so the returned value is a
  # subgraph view, as documented, rather than whatever the caller passed in.
  sgv = subgraph.make_view(sgv)
  _, detached_inputs = detach_inputs(sgv, control_inputs)
  _, detached_outputs = detach_outputs(sgv, control_outputs)
  return sgv, detached_inputs, detached_outputs


def connect(sgv0, sgv1, disconnect_first=False):
  """Connect the outputs of sgv0 to the inputs of sgv1.

  Args:
    sgv0: the first subgraph to have its outputs swapped. This argument is
      converted to a subgraph using the same rules as the function
      subgraph.make_view.
      Note that sgv0 is modified in place.
    sgv1: the second subgraph to have its outputs swapped. This argument is
      converted to a subgraph using the same rules as the function
      subgraph.make_view.
      Note that sgv1 is modified in place.
    disconnect_first: if True the current outputs of sgv0 are disconnected.
  Returns:
    A tuple `(sgv0, sgv1)` of the now connected subgraphs.
  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  sgv0 = subgraph.make_view(sgv0)
  sgv1 = subgraph.make_view(sgv1)
  util.check_graphs(sgv0, sgv1)
  if disconnect_first:
    detach_outputs(sgv0)
  # Wrap sgv0's output tensors in a pass-through view so they can be used as
  # the source side of reroute_inputs.
  sgv0_outputs = subgraph.SubGraphView(passthrough_ts=sgv0.outputs)
  reroute.reroute_inputs(sgv0_outputs, sgv1)
  return sgv0, sgv1


def bypass(sgv):
  """Bypass the given subgraph by connecting its inputs to its outputs.

  Args:
    sgv: the subgraph view to be bypassed. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
  Returns:
    A tuple `(sgv, detached_inputs)` where:
      `sgv` is a new subgraph view of the bypassed subgraph;
      `detached_inputs` is a list of the created input placeholders.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
  sgv = subgraph.make_view(sgv)
  # Snapshot the original inputs before detach_inputs replaces them with
  # placeholders.
  sgv_inputs = list(sgv.inputs)
  sgv, detached_inputs = detach_inputs(sgv)
  reroute.reroute_ts(sgv_inputs, sgv.outputs)
  return sgv, detached_inputs