# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Utility to lift subgraphs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import six

from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops


def _graph_inputs(op):
  """Returns all upstream `Operation`s of `op`: producers of its data inputs
  plus its control inputs."""
  return [x.op for x in op.inputs] + list(op.control_inputs)


def _as_operation(op_or_tensor):
  """Normalizes a `Tensor` to its producing `Operation`; passes ops through."""
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


class UnliftableError(Exception):
  """Raised if a Tensor cannot be lifted from the graph."""
  pass


def _constant_inputs(op_or_tensor):
  """True iff every direct input of the op is a `Const` with no control deps.

  Used to decide whether a PlaceholderWithDefault's default value can be
  safely copied into the destination graph (a constant has no further
  dependencies that would also need lifting).
  """
  return all(_as_operation(i).type == u"Const"
             and not _as_operation(i).control_inputs
             for i in _graph_inputs(_as_operation(op_or_tensor)))


def _path_from(from_op, tensor, sources):
  """Find one path from `from_op` to `tensor`, ignoring `sources`.

  Used only to build a human-readable chain for the `UnliftableError`
  message, so it returns the path as a formatted string rather than a list.

  Args:
    from_op: A `tf.Operation`.
    tensor: A `tf.Operation` or `tf.Tensor`.
    sources: A list of `tf.Tensor`.

  Returns:
    A python string containing the path, or "??" if none is found.
  """
  # Seed with the producers of `sources` so the walk stops at them.
  visited_ops = set([x.op for x in sources])
  # Depth-first walk upstream (against dataflow) starting from `tensor`.
  ops_to_visit = [_as_operation(tensor)]
  # Maps each discovered op to the consumer op through which it was first
  # reached; walking this map downstream reconstructs one concrete path.
  some_op_output = {}
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)
    if op == from_op:
      # Found `from_op`; follow the discovery links back down to `tensor`
      # to materialize the path.
      path_op = op
      path = [path_op]
      final_op = _as_operation(tensor)
      while path_op != final_op:
        path_op = some_op_output[path_op]
        path.append(path_op)
      # Reversed so the message reads consumer <- ... <- producer.
      return " <- ".join(["%s (%s)" % (x.name, x.type) for x in reversed(path)])
    else:
      for inp in _graph_inputs(op):
        # NOTE(review): `inp` is an `Operation` while `sources` holds
        # `Tensor`s, so `inp not in sources` appears to never match; the
        # effective stop condition is the `visited_ops` seeding above —
        # confirm before relying on the `sources` membership test here.
        if inp not in visited_ops and inp not in sources:
          some_op_output[inp] = op
          ops_to_visit.append(inp)
  return "??"


def _map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                  op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Arguments:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholders upon which init_tensor depends and are not in
    sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  # Upstream depth-first walk from the terminal tensor/op.
  ops_to_visit = [_as_operation(init_tensor)]
  # Placeholders encountered during the walk that were not declared as
  # sources; returned so the caller can promote them to sources.
  extra_sources = set()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      # Caller supplied an explicit deny-list and this op is on it.
      should_raise = True
    elif op.type == "Placeholder":
      # With no explicit deny-list, every placeholder is disallowed unless
      # the caller opted into auto-promoting them via `add_sources`.
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s"
          % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
    for inp in _graph_inputs(op):
      # Record that `op` is a consumer of `inp` that belongs to the lifted
      # subgraph; lift_to_graph uses these counts for its topological sort.
      op_outputs[inp].add(op)
      # NOTE(review): `(sources or extra_sources)` short-circuits — when
      # `sources` is non-empty, `extra_sources` is never consulted here.
      # Also, `inp` is an `Operation` while both sets hold `Tensor`s, so
      # this membership test appears vacuous (traversal is actually bounded
      # by `visited_ops`, which the caller seeds from `sources`). Confirm
      # intent before changing.
      if inp not in visited_ops and inp not in (sources or extra_sources):
        ops_to_visit.append(inp)

  return extra_sources


def _copy_non_source(op, graph, op_map):
  """Copy an op directly to a given graph.

  This function assumes that all of the inputs to an op have already been
  copied.

  Args:
    op: The op to be copied.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
      Mutated: the copied op and each of its output tensors are added.
  """
  # Inputs must already be present in op_map (callers copy in reverse
  # topological order), so these lookups cannot miss.
  copied_inputs = [op_map[x] for x in op.inputs]
  copied_control_inputs = [op_map[x] for x in op.control_inputs]
  # Preserve the original op's control dependencies and device placement.
  with ops.control_dependencies(copied_control_inputs), ops.device(op.device):
    copied_op = graph.create_op(
        op_type=op.type,
        inputs=copied_inputs,
        dtypes=[x.dtype for x in op.outputs],
        attrs=op.node_def.attr,
        name=op.name)
  op_map[op] = copied_op
  # Map each original output tensor to its counterpart on the copied op so
  # downstream consumers can be wired up.
  for i, o in enumerate(op.outputs):
    op_map[o] = copied_op.outputs[i]


def _copy_source(s, graph, op_map, handle_captures, inverse_captures):
  """Create a source in a graph based on a Tensor from a different graph.

  This function creates a placeholder analog of `s` in a graph with the
  following behavior:

  1) If s is a captured Tensor or Variable and handle_captures is set to True,
     simply capture it in the new graph as well.

  2) If s is a PlaceholderWithDefault whose default is a constant, preserve
     said default in the new graph.

  3) When applicable, copy resource variable metadata from `s` to the newly
     created placeholder.

  Args:
    s: The source of interest.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
      Mutated: `s` (and, in case 2, the copied default value) are added.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    inverse_captures: A dict mapping s back to the Tensor or Variable that it
      captures.
  """
  if handle_captures and s in inverse_captures:
    # Case 1: re-capture the original external tensor/variable in `graph`,
    # reusing the old op name.
    copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Case 2: copy the (constant) default value to the graph first, then
    # recreate the PlaceholderWithDefault pointing at the copied constant.
    default_value = s.op.inputs[0]
    _copy_non_source(op=default_value.op, graph=graph, op_map=op_map)

    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=s.op.name)
  else:
    # Fallback: a plain placeholder with the same dtype/shape/name/device.
    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=s.op.name)

  # Case 3: if `s` carries resource handle shape/type metadata (e.g. it is a
  # resource variable handle), propagate it to the new placeholder.
  base_handle = resource_variable_ops.get_resource_handle_data(s)
  if base_handle.shape_and_type:
    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        copied_placeholder,
        base_handle,
        graph_mode=True)

  op_map[s] = copied_placeholder


def lift_to_graph(init_tensors, graph, sources=None,
                  disallowed_placeholders=None, add_sources=False,
                  handle_captures=False, base_graph=None):
  """Copies the tensor and all its inputs recursively to the outer graph.

  Args:
    init_tensors: The Tensor to lift.
    graph: The graph to lift to.
    sources: Optional sequence of nodes to start from. If omitted the whole
      subgraph which feeds into `init_tensor` is lifted.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    base_graph: The graph from which to lift ops. This will be inferred if not
      specified.

  Returns:
    A mapping from ops in the current default graph to ops in `graph`.

  Raises:
    UnliftableError: If a placeholder blocks lifting.
  """
  # ResourceVariables are not graph nodes; they are separated out and passed
  # through op_map unchanged (see the identity mapping below).
  variable_init_tensors = {i for i in init_tensors if isinstance(
      i, resource_variable_ops.ResourceVariable)}
  init_tensors = set(init_tensors).difference(variable_init_tensors)
  # NOTE(review): this indexing assumes at least one non-variable init tensor
  # remains; an all-variable `init_tensors` would raise IndexError here unless
  # `base_graph` is given — confirm callers guarantee this.
  base_graph = base_graph or list(init_tensors)[0].graph

  # Check that the initializer does not depend on any placeholders.
  sources = set(sources or [])
  visited_ops = set([x.op for x in sources])
  op_outputs = collections.defaultdict(set)

  # First we extract the subgraph between init_tensors and sources.
  # _map_subgraph mutates visited_ops/op_outputs, so later iterations skip
  # ops already claimed by earlier init tensors; any placeholders it finds
  # are promoted to sources for the copy phase below.
  for init_tensor in init_tensors:
    sources.update(_map_subgraph(
        init_tensor=init_tensor,
        sources=sources,
        disallowed_placeholders=disallowed_placeholders,
        visited_ops=visited_ops,
        op_outputs=op_outputs,
        add_sources=add_sources))

  # Topologically sort the nodes we've extracted. Now we know how many of their
  # outputs are part of this subgraph.
  ops_to_copy = []
  marked_ops = set([])
  # Start from init tensors that nothing else in the subgraph consumes
  # (the "sinks"); an op is appended only once all of its in-subgraph
  # consumers are marked, yielding a reverse topological order.
  ops_to_visit = [_as_operation(t) for t in init_tensors
                  if not op_outputs[_as_operation(t)]]
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in marked_ops:
      continue
    marked_ops.add(op)
    ops_to_copy.append(op)
    for inp in _graph_inputs(op):
      if (all(x in marked_ops for x in op_outputs[inp]) and
          inp not in sources):
        ops_to_visit.append(inp)

  # When lifting from one FuncGraph to another, we will need to capture the
  # relevant tensors as well.
  captures = collections.OrderedDict()
  if (isinstance(base_graph, func_graph.FuncGraph) and
      isinstance(graph, func_graph.FuncGraph)):
    captures = base_graph.captures
  # Maps a captured (internal) tensor back to the external tensor/variable it
  # stands for, so _copy_source can re-capture it in the destination graph.
  inverse_captures = {v: k for k, v in captures.items()}

  # ops_to_copy now holds a reverse topologically sorted list of ops which
  # ends in the initializer. We copy those to the outermost graph and
  # build the initialization op there.
  with graph.as_default():
    op_map = {i: i for i in variable_init_tensors}  # Pass through variables.
    source_ops = set()
    # Add the sources in the same order as the original graph.
    # Captured sources first (preserving capture order), then the rest.
    for s in six.itervalues(captures):
      if s in sources:
        sources.remove(s)
        source_ops.add(s.op)
        _copy_source(
            s=s,
            graph=graph,
            op_map=op_map,
            handle_captures=handle_captures,
            inverse_captures=inverse_captures)
    for s in sources:
      source_ops.add(s.op)
      _copy_source(
          s=s,
          graph=graph,
          op_map=op_map,
          handle_captures=handle_captures,
          inverse_captures=inverse_captures)

    # Copy in topological order (reverse of ops_to_copy), skipping ops that
    # were already materialized as sources above.
    for op in reversed(ops_to_copy):
      if op in source_ops:
        continue

      _copy_non_source(op=op, graph=graph, op_map=op_map)

    return op_map