# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
16"""Prototype decorator for defining legacy-graph-mode functions."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import weakref
23
24from tensorflow.python.eager import def_function
25from tensorflow.python.eager import function
26from tensorflow.python.eager import lift_to_graph
27from tensorflow.python.framework import func_graph
28from tensorflow.python.framework import importer
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import resource_variable_ops
32from tensorflow.python.ops import variable_scope
33from tensorflow.python.util import nest
34from tensorflow.python.util.tf_export import tf_export
35
36
37class VariableHolder(object):
38  """Holds variables for a python function."""

  def __init__(self, fn=None, share_variables=False):
    self._fn = fn

    self._variables = []

    self._share_variables = share_variables
    self._variables_by_name = {}

  @property
  def variables(self):
    return self._variables

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    name = kwargs.get("name", None)
    with ops.name_scope(name, "Variable") as name_scope:
      name = name_scope

    if self._share_variables:
      v = self._variables_by_name.get(name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables.append(v)
      if self._share_variables:
        self._variables_by_name[name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):
    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)
    return wrapped


# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
92  """Wraps a tf V1 piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    super(WrappedFunction, self).__init__(
        fn_graph, attrs=attrs, signature=signature)
    self._variable_holder = variable_holder
    if ops.executing_eagerly_outside_functions():
      # TODO(allenl): Make this work in 1.x?
      self._lift_unlifted_variables()

  def _lift_unlifted_variables(self):
    """Finds resource variables and lifts them into the outer context.

    When we import a GraphDef inside a wrap_function, no Python graph building
    code runs. This means we get VarHandleOps which create variable resources,
    but no corresponding Python objects. Leaving them like this works but gives
    the user no way to interact with or modify the variables outside the graph.

    This method searches for variables and lifts them out as regular variable
    objects when possible, indicating to the FuncGraph that they are captures.
    """
    with self.graph.as_default():
      collection_variables = (
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
          + ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
      existing_captures = set(self.graph.internal_captures)
      lifted_variables = {}
      for old_variable in collection_variables:
        if (old_variable._in_graph_mode  # pylint: disable=protected-access
            and isinstance(old_variable,
                           resource_variable_ops.ResourceVariable)):
          if old_variable.handle in existing_captures:
            continue
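          # Make an eager variable object for the existing resource. The
          # placeholder initial value is never fed; it appears to stand in
          # only for the variable's shape and dtype, since the resource itself
          # is created and initialized by the imported graph.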
          new_variable = def_function.UnliftedInitializerVariable(
              array_ops.placeholder(
                  name="unused_{}_initializer".format(old_variable.op.name),
                  shape=old_variable.shape,
                  dtype=old_variable.dtype),
              name=old_variable.op.name,
              trainable=old_variable.trainable)
          self.graph.captures[new_variable.handle] = old_variable.handle
          existing_captures.add(old_variable.handle)
          lifted_variables[old_variable] = new_variable
          # pylint: disable=protected-access
          self._variable_holder._variables.append(new_variable)
          self.graph._weak_variables.append(weakref.ref(new_variable))
          # pylint: enable=protected-access
      # Update the graph's collections, partly for the user and partly so this
      # function is idempotent when it runs again in prune() calls.
      for collection_name in [ops.GraphKeys.GLOBAL_VARIABLES,
                              ops.GraphKeys.LOCAL_VARIABLES]:
        mutable_collection = ops.get_collection_ref(collection_name)
        for index, current in enumerate(mutable_collection):
          mutable_collection[index] = lifted_variables.get(current, current)

  def prune(self, feeds, fetches, name=None):
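    """Extracts a subgraph of this function as a new `WrappedFunction`.

    Args:
      feeds: A tensor or nested structure of tensors in this function's graph
        to use as inputs of the pruned function.
      fetches: A tensor, operation, or nested structure of tensors and
        operations in this function's graph to compute as its outputs.
      name: Optional name for the pruned function's graph; defaults to
        "pruned".

    Returns:
      A new `WrappedFunction` which computes `fetches` from `feeds` and shares
      this function's variables.
    """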
    name = name or "pruned"
    flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when the function uses variables.
    internal_captures = self.graph.internal_captures
    flat_feeds = [f for f in flat_feeds
                  if f not in internal_captures]

    tensor_fetches = []
    operation_fetches = []
    for f in flat_fetches:
      if isinstance(f, ops.Tensor):
        tensor_fetches.append(f)
      elif isinstance(f, ops.Operation):
        operation_fetches.append(f)
      else:
        raise ValueError("Fetches must be tensors or operations.")
    for f in flat_feeds + flat_fetches:
      if f.graph is not self._func_graph:
        raise ValueError(
            "Can only prune function whose feeds and fetches "
            "are from this graph (%s). Tensor %s from graph %s" % (
                self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
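      # Take an identity of each tensor fetch, gated on the operation fetches,
      # so that lift_to_graph has a single sink tensor to walk back from. When
      # only operations are fetched, a scalar zero stands in as the sink.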
      with ops.control_dependencies(operation_fetches):
        if tensor_fetches:
          identity_fetches = array_ops.identity_n(tensor_fetches)
          sink_tensor = identity_fetches[0]
        else:
          identity_fetches = []
          sink_tensor = array_ops.zeros([])
    lift_map = lift_to_graph.lift_to_graph(
        [sink_tensor], pruned_graph, sources=flat_feeds + internal_captures)
    for original_fetch, identity_fetch in zip(
        tensor_fetches, identity_fetches):
      lift_map[original_fetch] = lift_map[identity_fetch]
    pruned_graph.outputs.extend(
        lift_map[x] for x in flat_fetches if isinstance(x, ops.Tensor))
    if not tensor_fetches:
      pruned_graph.outputs.append(lift_map[sink_tensor])
    for external_capture, internal_capture in self.graph.captures.items():
      pruned_graph.captures[external_capture] = lift_map[internal_capture]
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    pruned_graph.inputs.extend(pruned_graph.captures.values())

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches)
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    pruned_fn._arg_keywords = []  # pylint: disable=protected-access
    return pruned_fn


class WrappedGraph(object):
  """Class for wrapping multiple TF 1.X functions in a single graph.

  Maintains a dictionary mapping names to wrapped functions. See
  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.

  Functions wrapped using this class have access to variables and collections
  created in other wrapped functions, using the standard TF 1.X API
  (`tf.compat.v1.get_variable` or
  `tf.compat.v1.get_default_graph().get_collection(...)`).

  Outside a function, variables and collections may be accessed using the
  `variables` and `graph` properties.

  Example:

  ```
  def add_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v + x

  def increment_var_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v.assign_add(x)

  g = WrappedGraph()
  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
  increment_var = g.wrap_function(increment_var_v1,
                                  [tf.TensorSpec([], tf.int32)])

  assert len(g.variables) == 1
  assert g.variables[0].numpy() == 0
  increment_var(tf.constant(5))
  assert g.variables[0].numpy() == 5
  ```
  """

  def __init__(self, variable_holder=None, **kwargs):
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    return self._functions

  @property
  def variables(self):
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and saves it to the functions dictionary.
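
    Args:
      fn: The Python function to wrap.
      signature: A list of `tf.TensorSpec`s (or a nested structure of them)
        describing the inputs to `fn`.
      name: Optional key under which to store the wrapped function; defaults
        to `fn.__name__`.

    Returns:
      A pruned `WrappedFunction` that accepts inputs matching `signature`.
    """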
    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        self._variable_holder.call_with_variable_creator_scope(fn),
        args=None, kwargs=None, signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior in `func_graph_from_py_func`:
    # if an existing FuncGraph is passed in via the `func_graph` argument, its
    # inputs and structured outputs are overwritten. This is likely a bug,
    # because the structured outputs no longer match the graph's outputs.
    # `graph.inputs` holds the argument placeholders followed by the capture
    # placeholders, so strip any captures off the end. (Slicing with
    # `[:-0]` would wrongly drop everything when there are no captures.)
    if self.graph.captures:
      fn_inputs = self.graph.inputs[:-len(self.graph.captures)]
    else:
      fn_inputs = self.graph.inputs
    fn_outputs = self.graph.structured_outputs

    wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs)
    name = name or fn.__name__
    self._functions[name] = wrapped_function
    return wrapped_function


@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function `fn` into a graph function.

  The Python function `fn` will be called once with symbolic arguments
  specified in the `signature`, traced, and turned into a graph function. Any
  variables created by `fn` will be owned by the object returned by
  `wrap_function`. The resulting graph function can be called with tensors
  which match the signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: The Python function to be wrapped.
    signature: The placeholder and Python arguments to be passed to the
      wrapped function.
    name: Optional. The name of the function.

  Returns:
    The wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None, kwargs=None, signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)


def function_from_graph_def(graph_def, inputs, outputs):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.

  Returns:
    A ConcreteFunction.
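
  Example (a hedged sketch, not part of the original file; the tensor names
  "x:0" and "y:0" depend on how the graph was built):

  ```python
  with tf.Graph().as_default() as g:
    x = tf.compat.v1.placeholder(tf.float32, name="x")
    tf.identity(x * 2.0, name="y")

  double = function_from_graph_def(g.as_graph_def(),
                                   inputs="x:0", outputs="y:0")
  double(tf.constant(3.0))  # => 6.0
  ```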
  """
  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")

  wrapped_import = wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      nest.map_structure(import_graph.as_graph_element, inputs),
      nest.map_structure(import_graph.as_graph_element, outputs))