# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining legacy-graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import weakref

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class VariableHolder(object):
  """Holds variables for a python function."""

  def __init__(self, fn=None, share_variables=False):
    self._fn = fn

    self._share_variables = share_variables
    self._variables_by_name = data_structures.Mapping()

  @property
  def variables(self):
    return self._variables_by_name

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    with ops.name_scope(
        kwargs.get("name", None), "Variable", skip_on_eager=False) as name:
      variable_name = ops.name_from_scope_name(name)
      kwargs["name"] = name

    if self._share_variables:
      v = self._variables_by_name.get(variable_name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables_by_name[variable_name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):

    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)

    return wrapped

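# A rough usage sketch for `VariableHolder` (illustration only, not part of
# the original file; `tf.Variable` is used as shorthand and `make_counter` is
# a hypothetical builder). The holder runs the wrapped callable under a
# `variable_creator_scope`, records each created variable under its
# scope-derived name, and adds it to the legacy collections:
#
#   def make_counter():
#     return tf.Variable(0, name="counter")
#
#   holder = VariableHolder(make_counter, share_variables=True)
#   counter = holder()  # Runs make_counter under the creator scope.
#   # holder.variables should now map "counter" to the created variable.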

def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`."""
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name),
        graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
        graph.get_tensor_by_name(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    struct_coder = nested_structure_coder.StructureCoder()
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = struct_coder.decode_proto(spec_proto)
    components = [graph.get_tensor_by_name(component.name) for component in
                  tensor_info.composite_tensor.components]
    return spec._from_components(components)  # pylint: disable=protected-access
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid "
                     "encodings are 'name', 'coo_sparse', and "
                     "'composite_tensor'.")


def _lift_single_variable(old_variable, graph, variable_holder):
  """Lifts `old_variable` out of the `FuncGraph` `graph`."""
  new_variable = resource_variable_ops.UninitializedVariable(
      shape=old_variable.shape,
      dtype=old_variable.dtype,
      name=old_variable.op.name,
      trainable=old_variable.trainable,
      extra_handle_data=old_variable.handle)
  new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
  graph.add_capture(new_variable.handle, old_variable.handle)
  # Now that we've added the new variable to graph.captures,
  # graph.capture will use that cached value and do some post-processing
  # on the capture like recording it on the tape.
  graph.capture(new_variable.handle)
  # pylint: disable=protected-access
  variable_name = new_variable.name.split(":")[0]
  variable_holder._variables_by_name[variable_name] = new_variable
  graph._weak_variables.append(weakref.ref(new_variable))
  # pylint: enable=protected-access
  graph.watch_variable(new_variable)
  return new_variable


def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    global_collection_variables = ops.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES)
    local_collection_variables = ops.get_collection(
        ops.GraphKeys.LOCAL_VARIABLES)
    existing_captures = {id(c) for c in graph.internal_captures}
    lifted_variables = {}

    def _should_lift_variable(v):
      return ((v._in_graph_mode  # pylint: disable=protected-access
               and v.graph.building_function)
              and isinstance(v, resource_variable_ops.BaseResourceVariable)
              and id(v.handle) not in existing_captures)

    for old_variable in global_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))

    for old_variable in local_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))
        if new_variable._in_graph_mode:  # pylint: disable=protected-access
          outer_graph = new_variable.graph
          # Variables are added to the global collection by default. In this
          # case we only want the variable in the local collection, so we'll pop
          # it out.
          global_collection = outer_graph.get_collection_ref(
              ops.GraphKeys.GLOBAL_VARIABLES)
          global_collection.remove(new_variable)
          outer_graph.add_to_collection(
              ops.GraphKeys.LOCAL_VARIABLES, new_variable)

    # Update the FuncGraph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(id(current), current)
        if not resource_variable_ops.is_resource_variable(
            mutable_collection[index]):
          logging.log_first_n(
              logging.WARN,
              "Unable to create a python object for variable {} because it is "
              "a reference variable. It may not be visible to training APIs. "
              "If this is a problem, consider rebuilding the SavedModel after "
              "running tf.compat.v1.enable_resource_variables().".format(
                  mutable_collection[index]),
              5)


# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
  """Wraps a TF 1.x piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    self._variable_holder = variable_holder
    _lift_unlifted_variables(fn_graph, variable_holder)
    # We call __init__ after lifting variables so that the function's signature
    # properly reflects the new captured inputs.
    for f in fn_graph.as_graph_def().library.function:
      context.context().add_function_def(f)
    self._signature = signature
    super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)

  def _call_impl(self, args, kwargs, cancellation_manager=None):
    if self._arg_keywords is None:
      if kwargs:
        raise NotImplementedError(
            "Keyword arguments are not supported when calling a "
            f"wrap_function-decorated function. Got {kwargs}.")
      if self._signature is not None:
        args = list(args)
        for i, arg in enumerate(args):
          if isinstance(self._signature[i], tensor_spec.DenseSpec):
            args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
      return self._call_flat(args, self.captured_inputs)
    else:
      return super(WrappedFunction, self)._call_impl(
          args, kwargs, cancellation_manager)

  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
        object's graph that goes from `feeds` to `fetches`.
    """
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    flat_feeds = nest.flatten(feeds, expand_composites=True)
    flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("All members of argument `feeds` must be tensors. "
                         f"Got {f} with type {type(f)}.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = {id(c) for c in self.graph.internal_captures}
    flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures]

    operation_fetches = []
    tensor_fetches = []
    tensor_infos = []

    def _fetch_preprocessing_callback(fetch):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a Tensor.
      """
      if isinstance(fetch, ops.Operation):
        operation_fetches.append(fetch)
        return fetch
      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
        tensor_infos.append(fetch)
        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
        if (tensor_util.is_tf_type(decoded) or
            isinstance(decoded, composite_tensor.CompositeTensor)):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
        tensor_fetches.append(fetch)
        return fetch
      else:
        graph_element = self.graph.as_graph_element(fetch)
        return _fetch_preprocessing_callback(graph_element)

    fetches = nest.map_structure(_fetch_preprocessing_callback, fetches)

    # Expand composite tensors into their component dense Tensors.
    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

    for f in flat_feeds + tensor_fetches + operation_fetches:
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches are "
                         f"from this graph ({self._func_graph}). Input "
                         f"{f} is from a different graph {f.graph}.")
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + self.graph.internal_captures,
        base_graph=self._func_graph)

    # Note that we add the component tensors of any composite tensors to the
    # returned function's outputs list; the list must contain these component
    # tensors, or the function's sparse outputs won't work properly.
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    for external_capture, internal_capture in self.graph.captures:
      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tf_type(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      """Callback for `nest.map_structure()`."""
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    # expand_composites=True here causes composite tensors to be expanded
    # into their component dense Tensors, mapped to the new graph, and then
    # reconstituted into their original composite form.
    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches, expand_composites=True)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn

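# A hedged usage sketch for `WrappedFunction.prune` (illustration only, not
# part of the original file): `prune` slices out a callable that computes
# `fetches` from `feeds`, much like a Session.run call with fixed feed and
# fetch lists. `build_graph_fn` is a hypothetical builder that creates a
# placeholder named "x" and an output named "y", and `tf.constant` is used as
# shorthand:
#
#   wrapped = wrap_function(build_graph_fn, signature=[])
#   x = wrapped.graph.get_tensor_by_name("x:0")
#   y = wrapped.graph.get_tensor_by_name("y:0")
#   pruned = wrapped.prune(feeds=x, fetches=y)
#   result = pruned(tf.constant(2.0))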

def _filter_returned_ops(fn):
  """Filters out any ops returned by the function.

  Args:
    fn: a function

  Returns:
    A tuple of (
      Wrapped function that returns `None` in place of any ops,
      dict that maps the index in the flat output structure to the returned op
    )
  """
  returned_ops = {}

  def wrap_and_filter_returned_ops(*args, **kwargs):
    outputs = fn(*args, **kwargs)
    flat_outputs = nest.flatten(outputs)
    for n in range(len(flat_outputs)):
      output = flat_outputs[n]
      if isinstance(output, ops.Operation):
        returned_ops[n] = output
        flat_outputs[n] = None
    return nest.pack_sequence_as(outputs, flat_outputs)

  return wrap_and_filter_returned_ops, returned_ops

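# Illustrative sketch (not part of the original file) of what
# `_filter_returned_ops` produces: any `Operation` in the flat output
# structure is replaced by `None`, and its flat index is remembered so
# `_wrap_function` can splice it back in later. `dense_tensor` and `some_op`
# are hypothetical graph-mode values:
#
#   def fn():
#     return dense_tensor, some_op
#
#   wrapped, returned_ops = _filter_returned_ops(fn)
#   wrapped()      # -> (dense_tensor, None)
#   returned_ops   # -> {1: some_op}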

class WrappedGraph(object):
  """Class for wrapping multiple TF 1.X functions in a single graph.

  Maintains a dictionary mapping names to wrapped functions. See
  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.

  Functions wrapped using this class have access to variables and collections
  created in other wrapped functions, using the standard TF 1.X API (
  `tf.compat.v1.get_variable` or
  `tf.compat.v1.get_default_graph().get_collection(...)`)

  Outside a function, variables and collections may be accessed using the
  `variables` and `graph` properties.

  Example:

  ```
  def add_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v + x

  def increment_var_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v.assign_add(x)

  g = WrappedGraph()
  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
  increment_var = g.wrap_function(increment_var_v1,
                                  [tf.TensorSpec([], tf.int32)])

  assert len(g.variables) == 1
  assert g.variables[0].numpy() == 0
  increment_var(tf.constant(5))
  assert g.variables[0].numpy() == 5
  ```
  """

  def __init__(self, variable_holder=None, **kwargs):
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    return self._functions

  @property
  def variables(self):
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.compat.v1.get_default_graph` to get the graph object
    within a function, or `WrappedGraph.graph` to get the graph outside a
    function). Variables created within the function will be added to the
    `variables` list.

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

      * The 1.X function may return tensors, variables, and ops. The wrapped
        eager-compatible function will always return tensors in the same nested
        structure.
      * Variables are replaced with a tensor containing the latest read values.
      * Returned ops are executed, and replaced with None.
      * The order of op execution and variable reads in the return is
        nondeterministic. For example:

        ```
        def update_var(x):
          v = tf.Variable(0)
          op = tf.compat.v1.assign(v, x).op
          return v, op

        g = WrappedGraph()
        fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
        read_value, _ = fn(tf.constant(3))
        print(read_value.numpy())  # could be 0 or 3
        print(g.variables[0].numpy()) # always 3
        ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

    Args:
      fn: a 1.X tensorflow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be saved
        with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
    return self._wrap_function(fn, signature=signature, name=name)

  def _wrap_function(self,
                     fn,
                     args=None,
                     kwargs=None,
                     signature=None,
                     name=None):
    """Internal wrap function method with extended func_graph arguments."""
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=args,
        kwargs=kwargs,
        signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. Pretty sure this is a bug,
    # because the structured outputs don't match up with the outputs...
    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(
        fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
    self._functions[name] = wrapped_function
    return wrapped_function


@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the wrapped
      function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)


def function_from_graph_def(graph_def, inputs, outputs):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.

  Returns:
    A ConcreteFunction.
  """

  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")

  wrapped_import = wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      nest.map_structure(import_graph.as_graph_element, inputs),
      nest.map_structure(import_graph.as_graph_element, outputs))
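# A minimal usage sketch for `function_from_graph_def` (illustration only, not
# part of the original file): import a frozen GraphDef and call it eagerly by
# naming its input and output tensors. The tensor names and input shape are
# hypothetical, and `tf.zeros` is used as shorthand:
#
#   concrete_fn = function_from_graph_def(
#       graph_def, inputs="input:0", outputs="output:0")
#   result = concrete_fn(tf.zeros([1, 224, 224, 3]))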