# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for serializing `Function`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest


def _serialize_function_spec(function_spec, coder):
  """Serialize a FunctionSpec object into its proto representation.

  Args:
    function_spec: The `FunctionSpec` to serialize. Its `fullargspec`,
      `is_method`, `input_signature` and `jit_compile` attributes are read.
    coder: The structure coder used to encode `fullargspec` and
      `input_signature`.

  Returns:
    A `saved_object_graph_pb2.FunctionSpec` proto.

  Raises:
    NotImplementedError: If `function_spec` describes a method whose
      `fullargspec` has no positional arguments (i.e. no named `self`).
  """
  if function_spec.is_method and not function_spec.fullargspec.args:
    raise NotImplementedError(
        "Missing support to serialize a method function without a named "
        "'self' argument.")
  proto = saved_object_graph_pb2.FunctionSpec()

  # Intentionally skip encoding annotations of a function because function
  # annotations are mainly for optional type checking during development
  # and do not affect runtime behavior.
  # https://www.python.org/dev/peps/pep-3107/
  # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
  proto.fullargspec.CopyFrom(
      coder.encode_structure(
          function_spec.fullargspec._replace(annotations={})))

  proto.is_method = function_spec.is_method
  proto.input_signature.CopyFrom(
      coder.encode_structure(function_spec.input_signature))

  # See `tf.function` and the JitCompile proto for details.
  proto.jit_compile = {
      None: saved_object_graph_pb2.FunctionSpec.JitCompile.DEFAULT,
      True: saved_object_graph_pb2.FunctionSpec.JitCompile.ON,
      False: saved_object_graph_pb2.FunctionSpec.JitCompile.OFF,
  }.get(function_spec.jit_compile)

  return proto


def _saved_concrete_function_name(concrete_function, name_map):
  """Returns the name under which `concrete_function` is recorded.

  Args:
    concrete_function: The concrete function whose name is looked up.
    name_map: Mapping from a concrete function's name (as text) to the name
      to record instead.

  Returns:
    The mapped name if `name_map` has an entry for the function's name,
    otherwise the function's own `name` unchanged.
  """
  return name_map.get(compat.as_text(concrete_function.name),
                      concrete_function.name)


def serialize_concrete_function(concrete_function, node_ids, coder):
  """Build a SavedConcreteFunction.

  Args:
    concrete_function: The concrete function to serialize.
    node_ids: Mapping from each of `concrete_function.captured_inputs` to the
      node id recorded in the proto's `bound_inputs`.
    coder: The structure coder used to encode the input and output
      signatures.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` proto.

  Raises:
    KeyError: If a captured input of `concrete_function` has no entry in
      `node_ids`. The original lookup error is chained as the cause.
  """
  bound_inputs = []
  try:
    for capture in concrete_function.captured_inputs:
      bound_inputs.append(node_ids[capture])
  except KeyError as e:
    # `capture` is still bound to the input whose lookup failed.
    raise KeyError(
        "Failed to add concrete function %s to object based saved model as it "
        "captures tensor %s which is unsupported or not reachable from root. "
        "One reason could be that a stateful object or a variable that the "
        "function depends on is not assigned to an attribute of the serialized "
        "trackable object "
        "(see SaveTest.test_captures_unreachable_variable)."
        % (concrete_function.name, capture)) from e
  concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
  structured_outputs = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  concrete_function_proto.canonicalized_input_signature.CopyFrom(
      coder.encode_structure(concrete_function.structured_input_signature))
  concrete_function_proto.output_signature.CopyFrom(
      coder.encode_structure(structured_outputs))
  concrete_function_proto.bound_inputs.extend(bound_inputs)
  return concrete_function_proto


def serialize_bare_concrete_function(concrete_function, name_map):
  """Build a SavedBareConcreteFunction.

  Args:
    concrete_function: The concrete function to serialize.
    name_map: Mapping from a concrete function's name (as text) to the name
      to record; unmapped names are recorded unchanged.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction` proto. Its
    `function_spec` is populated only when the function has a
    pre-initialized function spec.
  """
  # pylint: disable=protected-access
  proto = saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=_saved_concrete_function_name(
          concrete_function, name_map),
      allowed_positional_arguments=concrete_function._num_positional_args,
      argument_keywords=concrete_function._arg_keywords)
  if concrete_function._pre_initialized_function_spec is not None:
    coder = nested_structure_coder.StructureCoder()
    proto.function_spec.CopyFrom(
        _serialize_function_spec(
            concrete_function._pre_initialized_function_spec, coder))
  return proto
  # pylint: enable=protected-access


def serialize_function(function, name_map):
  """Build a SavedFunction proto.

  Args:
    function: The function to serialize. Its `function_spec` and the concrete
      functions listed for serialization are recorded.
    name_map: Mapping from a concrete function's name (as text) to the name
      to record; unmapped names are recorded unchanged.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` proto.
  """
  coder = nested_structure_coder.StructureCoder()
  proto = saved_object_graph_pb2.SavedFunction()

  function_spec_proto = _serialize_function_spec(function.function_spec, coder)
  proto.function_spec.CopyFrom(function_spec_proto)
  all_concrete_functions = \
      function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
  for concrete_function in all_concrete_functions:
    proto.concrete_functions.append(
        _saved_concrete_function_name(concrete_function, name_map))
  return proto
def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  The capture table of `concrete_function.graph` is temporarily rewritten
  while the wrapper is traced. It is always restored before this function
  returns or raises, so `concrete_function` is left unmodified.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Original `_captures` entries that are overwritten below, keyed the same
  # way as `_captures` (by `id(capture)`). A non-empty dict also means at
  # least one cached capture was found.
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      for capture, placeholder in concrete_function.graph.captures:
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is called to obtain the variable (presumably a
        # weak reference -- NOTE(review): confirm it cannot return None here).
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not remapped_captures:
      return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)
    fn = defun.ConcreteFunction(
        outer_graph, function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access
    return fn
  finally:
    # Return the captures to their original values, even if reading a cached
    # variable or tracing the wrapper raised part-way through. Otherwise the
    # original function's graph would keep pointing at `outer_graph` tensors.
    for key, capture in remapped_captures.items():
      captures[key] = capture