# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tools for deserializing `Function`s.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import re from tensorflow.core.framework import function_pb2 from tensorflow.python.eager import def_function from tensorflow.python.eager import function as function_lib from tensorflow.python.framework import func_graph as func_graph_lib from tensorflow.python.framework import function_def_to_graph as function_def_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.util import compat from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect def _is_tensor(t): return isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable)) def _call_concrete_function(function, inputs): """Calls a restored Function with structured inputs. This differs from `function.__call__` in that inputs and outputs are structured and that it casts inputs to tensors if needed. Note: this does not checks that non-tensor inputs match. That should be done before via `_concrete_function_callable_with`. Args: function: ConcreteFunction to call. inputs: Structured inputs compatible with `function.graph.structured_input_signature`. Returns: The structured function output. """ expected_structure = function.graph.structured_input_signature flatten_inputs = nest.flatten_up_to(expected_structure, inputs) tensor_inputs = [] for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): if isinstance(expected, tensor_spec.TensorSpec): tensor_inputs.append( ops.convert_to_tensor(arg, dtype_hint=expected.dtype)) result = function._call_flat(tensor_inputs) # pylint: disable=protected-access if isinstance(result, ops.Operation): return None return result def _try_convert_to_tensor_spec(arg, dtype_hint): """Returns None or TensorSpec obtained if `arg` is converted to tensor.""" try: # Note: try conversion in a FuncGraph to avoid poluting current context. with func_graph_lib.FuncGraph(name="guess_conversion").as_default(): result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint) return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype) except (TypeError, ValueError): return None def _concrete_function_callable_with(function, inputs, allow_conversion): """Returns whether concrete `function` can be called with `inputs`.""" expected_structure = function.graph.structured_input_signature try: flatten_inputs = nest.flatten_up_to(expected_structure, inputs) except (TypeError, ValueError): return False for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)): if isinstance(expected, tensor_spec.TensorSpec): if allow_conversion: arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype) if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec): return False if arg.dtype != expected.dtype: return False if not expected.shape.is_compatible_with(arg.shape): return False else: if arg != expected: return False return True def _deserialize_function_spec(function_spec_proto, coder): """Deserialize a FunctionSpec object from its proto representation.""" typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec) fullargspec = tf_inspect.FullArgSpec( args=typeless_fullargspec.args, varargs=typeless_fullargspec.varargs, varkw=typeless_fullargspec.varkw, defaults=typeless_fullargspec.defaults, kwonlyargs=typeless_fullargspec.kwonlyargs, kwonlydefaults=typeless_fullargspec.kwonlydefaults, annotations=typeless_fullargspec.annotations) is_method = function_spec_proto.is_method args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend) kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include) input_signature = coder.decode_proto(function_spec_proto.input_signature) return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend, kwargs_to_include, input_signature) # TODO(allenl): The fact that we can't derive ConcreteFunction calling # conventions from the serialized input spec right now is unfortunate. Merging # these would be good, maybe by adding TensorSpec names to cache keys so renamed # keyword arguments would yield different ConcreteFunctions. def setup_bare_concrete_function(saved_bare_concrete_function, concrete_functions): """Makes a restored bare concrete function callable.""" # Bare concrete functions accept only flat lists of Tensors with unique # names. concrete_function = concrete_functions[ saved_bare_concrete_function.concrete_function_name] # pylint: disable=protected-access concrete_function._arg_keywords = ( saved_bare_concrete_function.argument_keywords) concrete_function._num_positional_args = ( saved_bare_concrete_function.allowed_positional_arguments) # pylint: enable=protected-access concrete_function.add_to_graph() return concrete_function class RestoredFunction(def_function.Function): """Wrapper class for a function that has been restored from saved state. See `def_function.Function`. """ def __init__(self, python_function, name, function_spec, concrete_functions): # TODO(mdan): We may enable autograph once exceptions are supported. super(RestoredFunction, self).__init__( python_function, name, autograph=False) self._concrete_functions = concrete_functions self._function_spec = function_spec def _list_all_concrete_functions_for_serialization(self): return self._concrete_functions def recreate_function(saved_function, concrete_functions): """Creates a `Function` from a `SavedFunction`. Args: saved_function: `SavedFunction` proto. concrete_functions: map from function name to `ConcreteFunction`. Returns: A `Function`. """ # TODO(andresp): Construct a `Function` with the cache populated # instead of creating a new `Function` backed by a Python layer to # glue things together. Current approach is nesting functions deeper for each # serialization cycle. coder = nested_structure_coder.StructureCoder() function_spec = _deserialize_function_spec(saved_function.function_spec, coder) def restored_function_body(*args, **kwargs): """Calls a restored function.""" # This is the format of function.graph.structured_input_signature. At this # point, the args and kwargs have already been canonicalized. inputs = (args, kwargs) # First try to find a concrete function that can be called without input # conversions. This allows one to pick a more specific trace in case there # was also a more expensive one that supported tensors. for allow_conversion in [False, True]: for function_name in saved_function.concrete_functions: function = concrete_functions[function_name] if _concrete_function_callable_with(function, inputs, allow_conversion): return _call_concrete_function(function, inputs) available_signatures = [ concrete_functions[function_name].graph.structured_input_signature for function_name in saved_function.concrete_functions ] raise ValueError( "Could not find matching function to call for inputs %r. " "Only existing signatures are %r." % (inputs, available_signatures)) concrete_function_objects = [] for concrete_function_name in saved_function.concrete_functions: concrete_function_objects.append(concrete_functions[concrete_function_name]) restored_function = RestoredFunction( restored_function_body, restored_function_body.__name__, function_spec, concrete_function_objects) return tf_decorator.make_decorator( restored_function_body, restored_function, decorator_argspec=function_spec.fullargspec) def load_function_def_library(library): """Load a set of functions as concrete functions without captured inputs. Functions names are manipulated during load such that they do not overlap with previously created ones. Args: library: FunctionDefLibrary proto message. Returns: Map of original function names in the library to instances of `ConcreteFunction` without captured inputs. Raises: ValueError: if functions dependencies have a cycle. """ functions = {} load_shared_name_suffix = "_load_{}".format(ops.uid()) for fdef in _sort_function_defs(library): copy = _fix_fdef(fdef, functions, load_shared_name_suffix) func_graph = function_def_lib.function_def_to_graph(copy) for dep in _list_function_deps(fdef): functions[dep].add_to_graph(func_graph) func = function_lib.ConcreteFunction(func_graph) func.add_to_graph() functions[fdef.signature.name] = func # Also register the gradients in the current root context. with ops.init_scope(): func._register_gradient() # pylint: disable=protected-access return functions def _sort_function_defs(library): """Return a topologic sort of FunctionDefs in a library.""" edges = collections.defaultdict(list) in_count = collections.defaultdict(lambda: 0) for fdef in library.function: for dep in _list_function_deps(fdef): edges[dep].append(fdef.signature.name) in_count[fdef.signature.name] += 1 ready = [ fdef.signature.name for fdef in library.function if in_count[fdef.signature.name] == 0 ] output = [] while ready: node = ready.pop() output.append(node) for dest in edges[node]: in_count[dest] -= 1 if not in_count[dest]: ready.append(dest) if len(output) != len(library.function): failed_to_resolve = sorted(set(in_count.keys()) - set(output)) raise ValueError("There is a cyclic-dependency between functions. ", "Could not resolve %r." % (failed_to_resolve,)) reverse = {fdef.signature.name: fdef for fdef in library.function} return [reverse[x] for x in output] def _fix_fdef(orig_fdef, functions, shared_name_suffix): """Fixes a FunctionDef proto to be loaded in current context. In particular, when loading a function library into an eager context, one must rename the functions to avoid conflicts with existent functions. Args: orig_fdef: FunctionDef proto to fix. It is not modified. functions: map from function name to a ConcreteFunction instance. shared_name_suffix: A unique string for this load which helps to avoid `shared_name` collisions across loads. Two functions from the same load using the same `shared_name` still need to share, but functions from different loads with the same `shared_name` should not. Returns: A fixed copy of the original FunctionDef. """ fdef = function_pb2.FunctionDef() fdef.CopyFrom(orig_fdef) for node_def in fdef.node_def: if "_gradient_op_type" in node_def.attr: if node_def.op in ["StatefulPartitionedCall", "PartitionedCall"]: # TODO(andresp): This code assumes that the gradient registered for this # function call is the default gradient for the function and not a # custom one. fname = node_def.attr["f"].func.name node_def.attr["_gradient_op_type"].s = compat.as_bytes( functions[fname]._gradient_name) # pylint: disable=protected-access else: logging.warning("Importing a function (%s) with ops with custom " "gradients. Will likely fail if a gradient is " "requested.", fdef.signature.name) for _, attr_value in node_def.attr.items(): if attr_value.func.name: attr_value.func.name = functions[attr_value.func.name].name # TODO(b/124205571): Avoid accidental sharing and destruction of restored # resources. For now uniquify "shared_name" when loading functions to avoid # sharing. if "shared_name" in node_def.attr: node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix) fdef.signature.name = _clean_function_name(fdef.signature.name) return fdef def _list_function_deps(fdef): # TODO(andresp): Recurse into list attributes and into NameAttrList attrs both # when listing deps and when fixing them. `function_def_to_graph` also # requires fixes. deps = set() for node_def in fdef.node_def: for _, attr_value in node_def.attr.items(): if attr_value.WhichOneof("value") == "func": deps.add(attr_value.func.name) return deps def _clean_function_name(name): """Vanity function to keep the function names comprehensible.""" # Note: each time a function is wrapped into `function_lib.ConcreteFunction` # its name becomes "__inference__xyz". match = re.search(r"^__inference_(.*)_\d+$", name) if match: return match.group(1) else: return name