# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining legacy-graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import weakref

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class VariableHolder(object):
  """Holds variables for a python function."""

  def __init__(self, fn=None, share_variables=False):
    """Initializes the holder.

    Args:
      fn: Optional callable. It is the function invoked by `__call__`, run
        under this holder's variable creator scope.
      share_variables: If True, created variables are remembered by their
        name-scoped name and returned again when the same name is requested,
        instead of calling the next creator a second time.
    """
    self._fn = fn

    self._variables = []

    self._share_variables = share_variables
    self._variables_by_name = {}

  @property
  def variables(self):
    """The list of variables recorded by this holder."""
    return self._variables

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    # `collections` is handled here rather than forwarded to `next_creator`.
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    name = kwargs.get("name", None)
    with ops.name_scope(name, "Variable") as name_scope:
      name = name_scope

    if self._share_variables:
      v = self._variables_by_name.get(name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables.append(v)
      if self._share_variables:
        self._variables_by_name[name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      # Copy before extending so the caller's collection list is not mutated.
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    """Calls the held function under this holder's variable creator scope."""
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):
    """Returns a wrapper that runs `fn` inside the variable creator scope.

    Args:
      fn: The callable to wrap.

    Returns:
      A callable taking the same arguments as `fn` and returning its result.
    """
    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)
    return wrapped


# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
  """Wraps a tf V1 piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    """Builds the function from `fn_graph` and records its variable holder.

    Args:
      fn_graph: The graph passed to the `ConcreteFunction` constructor.
      variable_holder: The `VariableHolder` that tracks this function's
        variables; lifted variables are appended to it.
      attrs: Forwarded to the `ConcreteFunction` constructor.
      signature: Forwarded to the `ConcreteFunction` constructor.
    """
    super(WrappedFunction, self).__init__(
        fn_graph, attrs=attrs, signature=signature)
    self._variable_holder = variable_holder
    if ops.executing_eagerly_outside_functions():
      # TODO(allenl): Make this work in 1.x?
      self._lift_unlifted_variables()

  def _lift_unlifted_variables(self):
    """Finds resource variables and lifts them into the outer context.

    When we import a GraphDef inside a wrap_function, no Python graph building
    code runs. This means we get VarHandleOps which create variable resources,
    but no corresponding Python objects. Leaving them like this works but gives
    the user no way to interact with or modify the variables outside the graph.

    This method searches for variables and lifts them out as regular variable
    objects when possible, indicating to the FuncGraph that they are captures.
    """
    with self.graph.as_default():
      collection_variables = (
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
          + ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
      existing_captures = set(self.graph.internal_captures)
      lifted_variables = {}
      for old_variable in collection_variables:
        if (old_variable._in_graph_mode  # pylint: disable=protected-access
            and isinstance(old_variable,
                           resource_variable_ops.ResourceVariable)):
          if old_variable.handle in existing_captures:
            # Already captured; nothing to lift.
            continue
          new_variable = def_function.UnliftedInitializerVariable(
              array_ops.placeholder(
                  name="unused_{}_initializer".format(old_variable.op.name),
                  shape=old_variable.shape,
                  dtype=old_variable.dtype),
              name=old_variable.op.name,
              trainable=old_variable.trainable)
          self.graph.captures[new_variable.handle] = old_variable.handle
          existing_captures.add(old_variable.handle)
          lifted_variables[old_variable] = new_variable
          # pylint: disable=protected-access
          self._variable_holder._variables.append(new_variable)
          self.graph._weak_variables.append(weakref.ref(new_variable))
          # pylint: enable=protected-access
      # Update the graph's collections, partly for the user and partly so this
      # function is idempotent when it runs again in prune() calls.
      for collection_name in [ops.GraphKeys.GLOBAL_VARIABLES,
                              ops.GraphKeys.LOCAL_VARIABLES]:
        mutable_collection = ops.get_collection_ref(collection_name)
        for index, current in enumerate(mutable_collection):
          mutable_collection[index] = lifted_variables.get(current, current)

  def prune(self, feeds, fetches, name=None):
    """Extracts a sub-function computing `fetches` from `feeds`.

    Args:
      feeds: A tensor or nested structure of tensors from this function's
        graph. Feeds that are internal captures of the graph are ignored.
      fetches: A tensor/operation or nested structure of them from this
        function's graph. Operations map to `None` in the structured outputs.
      name: Optional name for the pruned graph. Defaults to "pruned".

    Returns:
      A new `WrappedFunction` sharing this function's variable holder.

    Raises:
      ValueError: If a feed is not a tensor, a fetch is neither a tensor nor
        an operation, or a feed/fetch belongs to a different graph.
    """
    name = name or "pruned"
    flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = self.graph.internal_captures
    flat_feeds = [f for f in flat_feeds
                  if f not in internal_captures]

    tensor_fetches = []
    operation_fetches = []
    for f in flat_fetches:
      if isinstance(f, ops.Tensor):
        tensor_fetches.append(f)
      elif isinstance(f, ops.Operation):
        operation_fetches.append(f)
      else:
        raise ValueError("Fetches must be tensors or operations.")
    for f in flat_feeds + flat_fetches:
      if f.graph is not self._func_graph:
        raise ValueError(
            "Can only prune function whose feeds and fetches "
            "are from this graph (%s). Tensor %s from graph %s" % (
                self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
      # A single sink tensor that depends on every fetch (tensor fetches via
      # identity_n, operation fetches via control dependencies) is what gets
      # lifted into the pruned graph.
      with ops.control_dependencies(operation_fetches):
        if tensor_fetches:
          identity_fetches = array_ops.identity_n(tensor_fetches)
          sink_tensor = identity_fetches[0]
        else:
          identity_fetches = []
          sink_tensor = array_ops.zeros([])
    lift_map = lift_to_graph.lift_to_graph(
        [sink_tensor], pruned_graph, sources=flat_feeds + internal_captures)
    for original_fetch, identity_fetch in zip(
        tensor_fetches, identity_fetches):
      lift_map[original_fetch] = lift_map[identity_fetch]
    pruned_graph.outputs.extend(
        lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
    if not tensor_fetches:
      pruned_graph.outputs.append(lift_map[sink_tensor])
    for external_capture, internal_capture in self.graph.captures.items():
      pruned_graph.captures[external_capture] = lift_map[internal_capture]
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    pruned_graph.inputs.extend(pruned_graph.captures.values())

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      """Maps a fetch to its lifted tensor, or None for operations."""
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches)
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    pruned_fn._arg_keywords = []  # pylint: disable=protected-access
    return pruned_fn


class WrappedGraph(object):
  """Class for wrapping multiple TF 1.X functions in a single graph.

  Maintains a dictionary mapping names to wrapped functions. See
  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.

  Functions wrapped using this class have access to variables and collections
  created in other wrapped functions, using the standard TF 1.X API (
  `tf.compat.v1.get_variable` or
  `tf.compat.v1.get_default_graph().get_collection(...)`)

  Outside a function, variables and collections may be accessed using the
  `variables` and `graph` properties.

  Example:

  ```
  def add_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v + x

  def increment_var_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v.assign_add(x)

  g = WrappedGraph()
  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
  increment_var = g.wrap_function(increment_var_v1,
                                  [tf.TensorSpec([], tf.int32)])

  assert len(g.variables) == 1
  assert g.variables[0].numpy() == 0
  increment_var(tf.constant(5))
  assert g.variables[0].numpy() == 5

  ```
  """

  def __init__(self, variable_holder=None, **kwargs):
    """Creates the shared graph and the function that wraps it.

    Args:
      variable_holder: Optional `VariableHolder`. When omitted, a holder with
        `share_variables=True` is created.
      **kwargs: `name` (default "wrapped_function_graph") and `collections`
        (default `{}`) are consumed here; everything else is forwarded to the
        `FuncGraph` constructor.
    """
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    """Dictionary mapping names to the functions wrapped so far."""
    return self._functions

  @property
  def variables(self):
    """The variables recorded by this graph's variable holder."""
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wrap a TF 1.X function and save to functions dictionary.

    Args:
      fn: The python function to trace into the shared graph.
      signature: The signature passed to `func_graph_from_py_func`.
      name: Optional key under which the result is stored in `functions`.
        Defaults to `fn.__name__`.

    Returns:
      The pruned `WrappedFunction` for `fn`.
    """
    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        self._variable_holder.call_with_variable_creator_scope(fn),
        args=None, kwargs=None, signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. Pretty sure this is a bug,
    # because structured outputs doesn't match up with the outputs...
    #
    # The trailing inputs are the graph's captures; strip them. When there are
    # no captures the slice must be skipped, because `inputs[:-0]` is the
    # empty list and would silently drop every real input.
    num_captures = len(self.graph.captures)
    if num_captures:
      fn_inputs = self.graph.inputs[:-num_captures]
    else:
      fn_inputs = list(self.graph.inputs)
    fn_outputs = self.graph.structured_outputs

    wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs)
    name = name or fn.__name__
    self._functions[name] = wrapped_function
    return wrapped_function


@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None, kwargs=None, signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)


def function_from_graph_def(graph_def, inputs, outputs):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.

  Returns:
    A ConcreteFunction.
  """
  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")

  wrapped_import = wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      nest.map_structure(import_graph.as_graph_element, inputs),
      nest.map_structure(import_graph.as_graph_element, outputs))