• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tools for serializing `Function`s."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import saved_object_graph_pb2
22from tensorflow.python.eager import function as defun
23from tensorflow.python.framework import func_graph as func_graph_module
24from tensorflow.python.saved_model import nested_structure_coder
25from tensorflow.python.util import compat
26from tensorflow.python.util import nest
27
28
def _serialize_function_spec(function_spec, coder):
  """Converts a `FunctionSpec` into a `saved_object_graph_pb2.FunctionSpec`.

  Args:
    function_spec: The spec to convert. Its `is_method`, `fullargspec`,
      `input_signature` and `jit_compile` attributes are read.
    coder: Structure coder whose `encode_structure` turns the argspec and the
      input signature into protos.

  Returns:
    A populated `saved_object_graph_pb2.FunctionSpec` message.

  Raises:
    NotImplementedError: If `function_spec` describes a method whose argspec
      lists no positional argument names (so there is no named `self`).
  """
  is_method = function_spec.is_method
  full_arg_spec = function_spec.fullargspec
  if is_method and not full_arg_spec.args:
    raise NotImplementedError(
        "Missing support to serialize a method function without a named "
        "'self' argument.")

  proto = saved_object_graph_pb2.FunctionSpec()

  # Annotations are dropped on purpose: they mostly serve optional type
  # checking during development and do not change runtime behavior.
  # https://www.python.org/dev/peps/pep-3107/
  # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
  argspec_without_annotations = full_arg_spec._replace(annotations={})
  proto.fullargspec.CopyFrom(
      coder.encode_structure(argspec_without_annotations))
  proto.is_method = is_method
  proto.input_signature.CopyFrom(
      coder.encode_structure(function_spec.input_signature))

  # Translate the tri-state Python setting (None/True/False) into the proto
  # enum; see `tf.function` and the JitCompile proto for details.
  jit_compile_enum = saved_object_graph_pb2.FunctionSpec.JitCompile
  jit_compile_by_setting = {
      None: jit_compile_enum.DEFAULT,
      True: jit_compile_enum.ON,
      False: jit_compile_enum.OFF,
  }
  proto.jit_compile = jit_compile_by_setting.get(function_spec.jit_compile)
  return proto
58
59
def serialize_concrete_function(concrete_function, node_ids, coder):
  """Build a SavedConcreteFunction.

  Args:
    concrete_function: The concrete function to serialize. Its
      `captured_inputs`, `name`, `structured_input_signature` and
      `structured_outputs` attributes are read.
    node_ids: Mapping from each captured input to the node id recorded for it
      in the saved object graph.
    coder: Structure coder whose `encode_structure` converts the input and
      output signatures into protos.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` whose `bound_inputs` hold
    the node ids of the function's captured inputs, in capture order.

  Raises:
    KeyError: If a captured input has no entry in `node_ids`, e.g. because a
      variable the function depends on is not reachable from the root object.
  """
  bound_inputs = []
  for capture in concrete_function.captured_inputs:
    # Guard only the lookup: a KeyError raised while iterating
    # `captured_inputs` must propagate unchanged rather than be misreported as
    # an unreachable capture (or trip over a not-yet-bound `capture`).
    try:
      bound_inputs.append(node_ids[capture])
    except KeyError:
      raise KeyError(
          "Failed to add concrete function %s to object based saved model as "
          "it captures tensor %s which is unsupported or not reachable from "
          "root. One reason could be that a stateful object or a variable that "
          "the function depends on is not assigned to an attribute of the "
          "serialized trackable object "
          "(see SaveTest.test_captures_unreachable_variable)."
          % (concrete_function.name, capture))
  concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
  structured_outputs = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  concrete_function_proto.canonicalized_input_signature.CopyFrom(
      coder.encode_structure(concrete_function.structured_input_signature))
  concrete_function_proto.output_signature.CopyFrom(
      coder.encode_structure(structured_outputs))
  concrete_function_proto.bound_inputs.extend(bound_inputs)
  return concrete_function_proto
84
85
def serialize_bare_concrete_function(concrete_function, name_map):
  """Builds a `SavedBareConcreteFunction` proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to describe. Its `name`,
      `_num_positional_args`, `_arg_keywords` and
      `_pre_initialized_function_spec` attributes are read.
    name_map: Mapping from the function's name (as text) to the name recorded
      in the proto; a name absent from the map is recorded unchanged.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction`. Its `function_spec`
    is filled in only when the concrete function carries a pre-initialized
    function spec.
  """
  # pylint: disable=protected-access
  original_name = concrete_function.name
  saved_name = name_map.get(compat.as_text(original_name), original_name)
  proto = saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=saved_name,
      allowed_positional_arguments=concrete_function._num_positional_args,
      argument_keywords=concrete_function._arg_keywords)

  spec = concrete_function._pre_initialized_function_spec
  if spec is None:
    return proto
  proto.function_spec.CopyFrom(
      _serialize_function_spec(spec, nested_structure_coder.StructureCoder()))
  return proto
  # pylint: enable=protected-access
102
103
def serialize_function(function, name_map):
  """Builds a `SavedFunction` proto describing a polymorphic function.

  Args:
    function: The function to save. Its `function_spec` is serialized, and
      `_list_all_concrete_functions_for_serialization()` supplies the concrete
      functions whose names are recorded.
    name_map: Mapping from a concrete function's name (as text) to the name
      recorded in the proto; a name absent from the map is recorded unchanged.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` holding the serialized function
    spec followed by the (possibly renamed) concrete function names, in the
    order the function lists them.
  """
  structure_coder = nested_structure_coder.StructureCoder()
  saved_function = saved_object_graph_pb2.SavedFunction()
  saved_function.function_spec.CopyFrom(
      _serialize_function_spec(function.function_spec, structure_coder))

  concretes = function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
  saved_function.concrete_functions.extend(
      name_map.get(compat.as_text(concrete.name), concrete.name)
      for concrete in concretes)
  return saved_function
118
119
def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  The original function's capture table is patched temporarily while the
  wrapper is traced. It is always restored before returning, including when
  tracing fails, so `concrete_function` is never left modified.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  # The original graph's capture table: id(external tensor) ->
  # (external tensor, internal placeholder). It is patched in place below.
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Original entries of every patched key, used to undo the patch.
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      # Iterate over a snapshot: the loop writes into the same mapping that
      # `graph.captures` exposes.
      for capture, placeholder in list(concrete_function.graph.captures):
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is a zero-argument callable returning the
        # variable (presumably a weak reference -- TODO confirm).
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not remapped_captures:
      return concrete_function

    # Built while the patched captures are in place, so its captured inputs
    # are the fresh reads created in `outer_graph`.
    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)
    fn = defun.ConcreteFunction(
        outer_graph, function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access
    return fn
  finally:
    # Return the captures to their original values, even if tracing failed.
    captures.update(remapped_captures)
177