# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Utility to lift subgraphs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


UnliftableError = op_selector.UnliftableError


def _as_operation(op_or_tensor):
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def _constant_inputs(op_or_tensor):
  return all(_as_operation(i).type == u"Const"
             and not _as_operation(i).control_inputs
             for i in op_selector.graph_inputs(_as_operation(op_or_tensor)))


# Represents an input to `copied_op` which must be updated once
# `old_graph_tensor` has been copied.
_InputMutation = collections.namedtuple(
    "_InputMutation",
    ["copied_op", "input_index", "old_graph_tensor"])


# Represents a control input to `copied_op` which must be added once
# `old_graph_op` has been copied.
_ControlMutation = collections.namedtuple(
    "_ControlMutation",
    ["copied_op", "old_graph_op"])


def _copy_non_source(op, graph, op_map, base_graph):
  """Copy an op directly to a given graph.

  Generally `op`'s inputs should already have been copied. If this is not the
  case, for example with v1 while_loops, then `_copy_non_source` inserts
  placeholders for the unavailable Tensors and returns a list of required
  mutations.

  Args:
    op: The op to be copied.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    base_graph: The graph we're copying from, for any necessary functions.

  Returns:
    A tuple of (required_inputs, required_control_inputs):
      required_inputs:
        A list of `_InputMutation` tuples containing inputs to `copied_op`
        which must be updated once `old_graph_tensor` has been copied.
      required_control_inputs:
        A list of `_ControlMutation` tuples containing control inputs to
        `copied_op` which must be added once `old_graph_op` has been copied.
  """
  input_mutations = []
  control_mutations = []
  copied_inputs = []
  for input_index, original_input in enumerate(op.inputs):
    copied_input = op_map.get(original_input, None)
    if copied_input is None:
      # An input for this op is missing due to a loop in the graph. We'll
      # insert a placeholder for now and return information about the
      # required post-hoc mutation.
      copied_input = array_ops.placeholder(
          name="unused_control_flow_input",
          shape=original_input.shape,
          dtype=original_input.dtype)
      input_mutations.append(
          # `copied_op` is filled in below, after we've created it.
          _InputMutation(copied_op=None,
                         input_index=input_index,
                         old_graph_tensor=original_input))
    copied_inputs.append(copied_input)

  copied_control_inputs = []
  for original_control_input in op.control_inputs:
    copied_control_input = op_map.get(original_control_input, None)
    if copied_control_input is None:
      control_mutations.append(
          _ControlMutation(copied_op=None,
                           old_graph_op=original_control_input))
    else:
      copied_control_inputs.append(copied_control_input)

  # Don't copy over the _tpu_replicate attribute. This attribute is used to
  # signal that the op was built inside a tpu_replicate context; if we're
  # lifting the op to another graph we're similarly lifting it into another
  # context.
  with ops.control_dependencies(copied_control_inputs), ops.device(op.device):
    # pylint: disable=protected-access
    f = base_graph._functions.get(op.type, None)
    if f is not None and compat.as_str(f.name) not in graph._functions:
      f.add_to_graph(graph)
    # pylint: enable=protected-access

    # Create a new op in the destination graph if it doesn't already exist.
    copied_op = graph.create_op(
        op_type=op.type,
        inputs=copied_inputs,
        dtypes=[x.dtype for x in op.outputs],
        attrs={
            key: value for key, value in op.node_def.attr.items()
            if not key.startswith("_class") and
            not key.startswith("_tpu_replicate")
        },  # b/128981532.
        name=op.name)
  op_map[op] = copied_op
  for i, o in enumerate(op.outputs):
    op_map[o] = copied_op.outputs[i]

  return ([mutation._replace(copied_op=copied_op)
           for mutation in input_mutations],
          [mutation._replace(copied_op=copied_op)
           for mutation in control_mutations])
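

# A minimal sketch (comments only; names are illustrative) of the
# copy-then-patch pattern `_copy_non_source` enables: copy every op first,
# collecting the returned mutations, then rewire the placeholder inputs once
# all of the referenced tensors exist in `op_map`. `lift_to_graph` below
# follows this pattern:
#
#   pending_inputs, pending_controls = [], []
#   for op in reverse_topological_order:  # hypothetical ordering
#     ins, ctls = _copy_non_source(op, graph, op_map, base_graph)
#     pending_inputs.extend(ins)
#     pending_controls.extend(ctls)
#   for m in pending_inputs:
#     m.copied_op._update_input(m.input_index, op_map[m.old_graph_tensor])
#   for m in pending_controls:
#     m.copied_op._add_control_input(op_map[m.old_graph_op])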


def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
                 base_graph):
  """Create a source in a graph based on a Tensor from a different graph.

  This function creates a placeholder analog of `s` in a graph with the
  following behavior:

  1) If s is a captured Tensor or Variable and handle_captures is set to True,
     simply capture it in the new graph as well.

  2) If s is a PlaceholderWithDefault whose default is a constant, preserve
     said default in the new graph.

  3) When applicable, copy resource variable metadata from `s` to the newly
     created placeholder.

  Args:
    s: The source of interest.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    inverse_captures: A dict mapping s back to the Tensor or Variable that it
      captures.
    base_graph: The graph being copied from.
  """
  if handle_captures and s in inverse_captures:
    copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Copy the default value to the graph.
    default_value = s.op.inputs[0]
    unavailable_inputs, unavailable_control_inputs = _copy_non_source(
        op=default_value.op, graph=graph, op_map=op_map,
        base_graph=base_graph)
    if unavailable_inputs or unavailable_control_inputs:
      raise AssertionError(
          "Could not copy source node {} because it has inputs."
          .format(default_value))

    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=s.op.name)
  else:
    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=s.op.name)

  base_handle = resource_variable_ops.get_resource_handle_data(s)
  if base_handle.shape_and_type:
    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        copied_placeholder,
        base_handle,
        graph_mode=True)

  op_map[s] = copied_placeholder
  # Also map the op of the source tensor, so that any nodes depending on that
  # op via control dependencies work correctly.
  op_map[s.op] = copied_placeholder.op
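

# Worked illustration (comments only) of case 2 above: suppose `s` in the
# base graph is
#
#   s = placeholder_with_default(constant(3.0), shape=[])
#
# and `s` is not a capture. `_copy_source` first copies the Const op into
# `graph` via `_copy_non_source`, then rebuilds the PlaceholderWithDefault
# around the copied constant, so the lifted tensor stays feedable with the
# same default it had in the base graph.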


@tf_export("__internal__.lift_to_graph", v1=[])
def lift_to_graph(tensors,
                  graph,
                  sources=None,
                  disallowed_placeholders=None,
                  add_sources=False,
                  handle_captures=False,
                  base_graph=None,
                  op_map=None):
  """Copies the tensors and all of their inputs recursively to the outer graph.

  Args:
    tensors: The Tensors to lift.
    graph: The graph to lift to.
    sources: Optional sequence of nodes to start from. If omitted the whole
      subgraph which feeds into `tensors` is lifted.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.
    handle_captures: A boolean indicating whether to re-capture sources in the
      new graph or simply create vanilla placeholders.
    base_graph: The graph from which to lift ops. This will be inferred if not
      specified.
    op_map: A map containing all the existing nodes that have already been
      lifted to the destination graph, so they won't be lifted and copied
      again.

  Returns:
    A mapping from ops in the current default graph to ops in `graph`.

  Raises:
    UnliftableError: If a placeholder blocks lifting.
  """
  variable_init_tensors = []
  init_tensors = []
  for tensor in tensors:
    if isinstance(tensor, resource_variable_ops.ResourceVariable):
      variable_init_tensors.append(tensor)
    else:
      init_tensors.append(tensor)
  base_graph = base_graph or init_tensors[0].graph
  op_map = op_map or object_identity.ObjectIdentityDictionary()

  # Check that the initializer does not depend on any placeholders.
  sources = object_identity.ObjectIdentitySet(sources or [])
  visited_ops = set(x.op for x in sources)
  op_outputs = collections.defaultdict(set)

  # First we extract the subgraph between init_tensors and sources.
  for init_tensor in init_tensors:
    sources.update(op_selector.map_subgraph(
        init_tensor=init_tensor,
        sources=sources,
        disallowed_placeholders=disallowed_placeholders,
        visited_ops=visited_ops,
        op_outputs=op_outputs,
        add_sources=add_sources))

  # Try to topologically sort the nodes we've extracted. Having mapped the
  # subgraph, we now know how many of each op's outputs are part of it.
  ops_to_copy = []
  marked_ops = set([])
  ops_to_visit = [_as_operation(t) for t in init_tensors
                  if not op_outputs[_as_operation(t)]]
  unvisited_ops = set(ops_to_visit)
  while unvisited_ops:
    while ops_to_visit:
      op = ops_to_visit.pop()
      if op in marked_ops:
        continue
      marked_ops.add(op)
      ops_to_copy.append(op)
      for inp in op_selector.graph_inputs(op):
        # Don't lift TPUReplicateMetadata nodes out of the function, because
        # they have no registered kernels.
        if inp.type == "TPUReplicateMetadata":
          continue
        unvisited_ops.add(inp)
        if (all(x in marked_ops for x in op_outputs[inp]) and
            inp not in sources):
          ops_to_visit.append(inp)
    unvisited_ops.difference_update(marked_ops)
    if unvisited_ops:
      # `unvisited_ops` should only have elements if the graph has a loop. In
      # this case we want to keep copying and there's no topological ordering;
      # we'll do ugly post-hoc mutations instead.
      ops_to_visit.append(next(iter(unvisited_ops)))
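
  # Worked illustration: for `c = a + b` followed by `d = c * c`, with
  # init_tensors=[d] and no sources, the loop above starts at d's op and walks
  # toward the leaves, so `ops_to_copy` ends up in reverse topological order
  # (e.g. [mul, add, a_op, b_op]). The copy loop below therefore iterates
  # over `reversed(ops_to_copy)`.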

  # When lifting from one FuncGraph to another, we will need to capture the
  # relevant tensors as well.
  captures = []
  inverse_captures = object_identity.ObjectIdentityDictionary()
  internal_captures = []
  if (isinstance(base_graph, func_graph.FuncGraph) and
      isinstance(graph, func_graph.FuncGraph)):
    captures = base_graph.captures
    for external_capture, internal_capture in captures:
      inverse_captures[internal_capture] = external_capture
    internal_captures = base_graph.internal_captures

  # ops_to_copy now holds a reverse topologically sorted list of ops which
  # ends in the initializer. We copy those to the outermost graph and
  # build the initialization op there.
  with graph.as_default():
    for i in variable_init_tensors:
      op_map[i] = i
    source_ops = set()
    # Add the sources in the same order as the original graph.
    for s in internal_captures:
      if s in sources:
        sources.remove(s)
        source_ops.add(s.op)
        _copy_source(
            s=s,
            graph=graph,
            op_map=op_map,
            handle_captures=handle_captures,
            inverse_captures=inverse_captures,
            base_graph=base_graph)
    for s in sources:
      source_ops.add(s.op)
      _copy_source(
          s=s,
          graph=graph,
          op_map=op_map,
          handle_captures=handle_captures,
          inverse_captures=inverse_captures,
          base_graph=base_graph)

    input_mutations = []
    control_mutations = []
    for op in reversed(ops_to_copy):
      if op in source_ops or op in op_map:
        continue
      new_input_mutations, new_control_mutations = _copy_non_source(
          op=op, graph=graph, op_map=op_map, base_graph=base_graph)
      input_mutations.extend(new_input_mutations)
      control_mutations.extend(new_control_mutations)

    # Mutate the new graph to insert any loops which existed in the source
    # graph due to v1 while_loops.
    #
    # pylint: disable=protected-access
    with graph._mutation_lock():
      for mutation in input_mutations:
        mutation.copied_op._update_input(
            mutation.input_index, op_map[mutation.old_graph_tensor])
      for mutation in control_mutations:
        # Don't lift TPUReplicateMetadata nodes out of the function, because
        # they have no registered kernels.
        if mutation.old_graph_op.type == "TPUReplicateMetadata":
          continue
        mutation.copied_op._add_control_input(op_map[mutation.old_graph_op])
    # pylint: enable=protected-access

  return op_map
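

# A minimal usage sketch of `lift_to_graph`, kept for documentation and never
# called from this module. It assumes TF 2.x; the helper name
# `_example_lift_concrete_function` and the "lifted" graph name are
# illustrative, not part of the public API.
def _example_lift_concrete_function():
  """Lifts a ConcreteFunction's outputs into a fresh FuncGraph (sketch)."""
  # pylint: disable=g-import-not-at-top
  from tensorflow.python.eager import def_function
  from tensorflow.python.framework import dtypes
  from tensorflow.python.framework import tensor_spec
  # pylint: enable=g-import-not-at-top

  @def_function.function
  def f(x):
    return x * 2.

  concrete = f.get_concrete_function(
      tensor_spec.TensorSpec([], dtypes.float32))
  target = func_graph.FuncGraph("lifted")
  # The function body is copied into `target`; its input placeholders act as
  # sources, and any captures would be re-captured in `target`.
  return lift_to_graph(
      concrete.graph.outputs,
      target,
      sources=concrete.graph.inputs,
      handle_captures=True)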