• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Various functions for graph editing."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.graph_editor import reroute
22from tensorflow.contrib.graph_editor import select
23from tensorflow.contrib.graph_editor import subgraph
24from tensorflow.contrib.graph_editor import util
25from tensorflow.python.ops import array_ops as tf_array_ops
26
# Names exported by `from ... import *`. The detach* family isolates a
# subgraph view from the rest of the graph; connect and bypass rewire it.
__all__ = [
    "detach_control_inputs", "detach_control_outputs",
    "detach_inputs", "detach_outputs", "detach",
    "connect", "bypass",
]
36
37
def detach_control_inputs(sgv):
  """Detach all the external control inputs of the subgraph sgv.

  A control input of an operation of sgv is "external" when the operation it
  comes from is not itself part of sgv. Control dependencies between two
  operations of sgv are left untouched.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  # Build the membership set once; testing `cop not in sgv.ops` against the
  # ops sequence for every control input of every op is quadratic.
  sgv_ops = frozenset(sgv.ops)
  for op in sgv.ops:
    external_cops = [cop for cop in op.control_inputs if cop not in sgv_ops]
    if external_cops:  # Nothing to remove: avoid touching the op at all.
      reroute.remove_control_inputs(op, external_cops)
49
50
def detach_control_outputs(sgv, control_outputs):
  """Detach all the external control outputs of the subgraph sgv.

  For every operation of sgv, each operation outside sgv that depends on it
  through a control input has that control dependency removed.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
    control_outputs: a util.ControlOutputs instance. Its update() method is
      called before it is queried.
  Raises:
    TypeError: if control_outputs is not a util.ControlOutputs instance.
  """
  if not isinstance(control_outputs, util.ControlOutputs):
    # The message must be formatted explicitly; passing the type as a second
    # exception argument would leave a literal "{}" in the message.
    raise TypeError("Expected a util.ControlOutputs, got: {}".format(
        type(control_outputs)))
  control_outputs.update()
  sgv = subgraph.make_view(sgv)
  sgv_ops = frozenset(sgv.ops)  # O(1) membership instead of a sequence scan.
  for op in sgv.ops:
    for cop in control_outputs.get(op):
      if cop not in sgv_ops:
        reroute.remove_control_inputs(cop, op)
  # NOTE(review): control_outputs is not refreshed after these removals, so it
  # presumably no longer reflects the graph; callers may need to update() it.
68
69
def detach_inputs(sgv, control_inputs=False):
  """Swap the inputs of a subgraph view with freshly created placeholders.

  One placeholder is created per input tensor of sgv. It has the same dtype as
  the tensor and is named with util.placeholder_name. No static shape is given
  to it. The inputs of sgv are then swapped with these placeholders.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_inputs: if True, the external control inputs of sgv are detached
      as well (see detach_control_inputs).
  Returns:
    A tuple `(sgv, input_placeholders)` where
      `sgv` is the subgraph view of the detached subgraph;
      `input_placeholders` is the list of created placeholders, one per input
        of sgv and in the same order.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  view = subgraph.make_view(sgv)

  placeholders = []
  with view.graph.as_default():
    for tensor in view.inputs:
      placeholders.append(
          tf_array_ops.placeholder(
              dtype=tensor.dtype, name=util.placeholder_name(tensor)))

  reroute.swap_inputs(view, placeholders)
  if control_inputs:
    detach_control_inputs(view)
  return view, placeholders
99
100
def detach_outputs(sgv, control_outputs=None):
  """Detach the outputs of a subgraph view.

  Every output of sgv that has at least one consumer is replaced, from the
  point of view of its consumers, by a placeholder built with
  util.make_placeholder_from_tensor.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_outputs: a util.ControlOutputs instance or None. If not None the
      control outputs are also detached.
  Returns:
    A tuple `(sgv, output_placeholders)` where
      `sgv` is a new subgraph view of the detached subgraph, restricted to the
        outputs that had consumers;
      `output_placeholders` is a list of the created output placeholders, where
        `output_placeholders[i]` replaces `sgv.outputs[i]`.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  # Only outputs that are actually consumed need to be detached.
  sgv_ = sgv.remap_outputs([output_id
                            for output_id, output_t in enumerate(sgv.outputs)
                            if output_t.consumers()])

  # swap_outputs pairs sgv_.outputs[i] with output_placeholders[i], so the
  # placeholders are built from sgv_.outputs in exactly that order. Building
  # them from the consumers' inputs instead follows the consumer ops' input
  # order, which can differ and would pair an output with a placeholder made
  # for another tensor (wrong name/shape/dtype).
  with sgv_.graph.as_default():
    output_placeholders = [
        util.make_placeholder_from_tensor(output_t)
        for output_t in sgv_.outputs
    ]

  reroute.swap_outputs(sgv_, output_placeholders)
  if control_outputs is not None:
    detach_control_outputs(sgv_, control_outputs)
  return sgv_, output_placeholders
139
140
def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None):
  """Detach both the inputs and the outputs of a subgraph view.

  Args:
    sgv: the subgraph view to be detached. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios:  An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to set
      control_inputs to True and control_outputs to the util.ControlOutputs
      instance.
  Returns:
    A tuple `(sgv, detached_inputs, detached_outputs)` where:
    `sgv` is the subgraph view (built by subgraph.make_view from the sgv
      argument) of the detached subgraph;
    `detached_inputs` is a list of the created input placeholders;
    `detached_outputs` is a list of the created output placeholders.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  control_inputs, control_outputs = select.check_cios(control_inputs,
                                                      control_outputs,
                                                      control_ios)
  # Keep the view produced by detach_inputs. Both halves then work on the same
  # SubGraphView (make_view runs only once), and the caller gets a view back
  # instead of the raw, unconverted argument.
  sgv, detached_inputs = detach_inputs(sgv, control_inputs)
  _, detached_outputs = detach_outputs(sgv, control_outputs)
  return sgv, detached_inputs, detached_outputs
170
171
def connect(sgv0, sgv1, disconnect_first=False):
  """Connect the outputs of sgv0 to the inputs of sgv1.

  Args:
    sgv0: the subgraph whose outputs feed sgv1. This argument is converted to
      a subgraph using the same rules as the function subgraph.make_view.
      Note that sgv0 is modified in place.
    sgv1: the subgraph whose inputs are rerouted to the outputs of sgv0. This
      argument is converted to a subgraph using the same rules as the function
      subgraph.make_view.
      Note that sgv1 is modified in place.
    disconnect_first: if True, the outputs of sgv0 are first detached from
      their current consumers (see detach_outputs).
  Returns:
    A tuple `(sgv0, sgv1)` of the now connected subgraphs.
  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  first = subgraph.make_view(sgv0)
  second = subgraph.make_view(sgv1)
  util.check_graphs(first, second)

  if disconnect_first:
    detach_outputs(first)

  # An op-less view over sgv0's output tensors. Presumably its passthrough
  # tensors serve as the "inputs" that reroute_inputs hands to sgv1 — confirm
  # against the reroute module.
  producer_view = subgraph.SubGraphView(passthrough_ts=first.outputs)
  reroute.reroute_inputs(producer_view, second)
  return first, second
199
200
def bypass(sgv):
  """Bypass the given subgraph by connecting its inputs to its outputs.

  Args:
    sgv: the subgraph view to be bypassed. This argument is converted to a
      subgraph using the same rules as the function subgraph.make_view.
      Note that sgv is modified in place.
  Returns:
    A tuple `(sgv, detached_inputs)` where:
      `sgv` is a new subgraph view of the bypassed subgraph;
      `detached_inputs` is a list of the created input placeholders.
  Raises:
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  # TODO(fkp): allows to plug sgv.inputs to individual sgv.outputs consumers
  view = subgraph.make_view(sgv)
  # Snapshot the original input tensors before detach_inputs swaps them out
  # for placeholders.
  original_inputs = list(view.inputs)
  view, placeholders = detach_inputs(view)
  # Presumably reroute_ts makes the consumers of view.outputs[i] consume
  # original_inputs[i] instead, so input and output counts must match —
  # confirm against the reroute module.
  reroute.reroute_ts(original_inputs, view.outputs)
  return view, placeholders
222