# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for serializing `Function`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.saved_model import nested_structure_coder


def _serialize_function_spec(function_spec, coder):
  """Serialize a FunctionSpec object into its proto representation.

  Args:
    function_spec: The `FunctionSpec` to serialize. Its `fullargspec`,
      `is_method`, `args_to_prepend`, `kwargs_to_include` and
      `input_signature` attributes are read.
    coder: An object with an `encode_structure` method used to encode the
      nested-structure attributes (in this module it is a
      `nested_structure_coder.StructureCoder`).

  Returns:
    A `saved_object_graph_pb2.FunctionSpec` proto.
  """
  proto = saved_object_graph_pb2.FunctionSpec()
  proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
  proto.is_method = function_spec.is_method
  proto.args_to_prepend.CopyFrom(
      coder.encode_structure(function_spec.args_to_prepend))
  proto.kwargs_to_include.CopyFrom(
      coder.encode_structure(function_spec.kwargs_to_include))
  proto.input_signature.CopyFrom(
      coder.encode_structure(function_spec.input_signature))
  return proto


def serialize_concrete_function(concrete_function, node_ids, coder):
  """Build a SavedConcreteFunction.

  Args:
    concrete_function: The concrete function to serialize. Its `name`,
      `captured_inputs`, `structured_input_signature` and
      `structured_outputs` attributes are read.
    node_ids: A mapping from each captured input of `concrete_function` to
      the node id recorded in the proto's `bound_inputs`.
    coder: An object with an `encode_structure` method, presumably a
      `nested_structure_coder.StructureCoder`, used to encode the input and
      output signatures.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` proto.

  Raises:
    KeyError: If one of the function's captured inputs has no entry in
      `node_ids`.
  """
  bound_inputs = []
  # An explicit loop (rather than a comprehension) keeps `capture` in scope,
  # so the error below can name the capture that could not be resolved.
  try:
    for capture in concrete_function.captured_inputs:
      bound_inputs.append(node_ids[capture])
  except KeyError:
    raise KeyError(
        "Failed to add concrete function %s to object based saved model as it "
        "captures tensor %s which is unsupported or not reachable from root. "
        "One reason could be that a stateful object or a variable that the "
        "function depends on is not assigned to an attribute of the serialized "
        "trackable object "
        "(see SaveTest.test_captures_unreachable_variable)."
        % (concrete_function.name, capture))
  concrete_function_proto = saved_object_graph_pb2.SavedConcreteFunction()
  # Outputs are converted to their signature form before being encoded.
  structured_outputs = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  concrete_function_proto.canonicalized_input_signature.CopyFrom(
      coder.encode_structure(concrete_function.structured_input_signature))
  concrete_function_proto.output_signature.CopyFrom(
      coder.encode_structure(structured_outputs))
  concrete_function_proto.bound_inputs.extend(bound_inputs)
  return concrete_function_proto


def serialize_bare_concrete_function(concrete_function):
  """Build a SavedBareConcreteFunction.

  Args:
    concrete_function: The concrete function to serialize. Its `name` and the
      private `_num_positional_args` and `_arg_keywords` attributes are read.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction` proto.
  """
  # pylint: disable=protected-access
  return saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=concrete_function.name,
      allowed_positional_arguments=concrete_function._num_positional_args,
      argument_keywords=concrete_function._arg_keywords)
  # pylint: enable=protected-access


def serialize_function(function):
  """Build a SavedFunction proto.

  Args:
    function: The function to serialize. Its `function_spec` is encoded, and
      the names of the concrete functions returned by its
      `_list_all_concrete_functions_for_serialization()` method are recorded.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` proto.
  """
  coder = nested_structure_coder.StructureCoder()
  proto = saved_object_graph_pb2.SavedFunction()

  function_spec_proto = _serialize_function_spec(function.function_spec, coder)
  proto.function_spec.CopyFrom(function_spec_proto)
  all_concrete_functions = (
      function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
  # Only names are stored here; the concrete functions are referenced by name.
  proto.concrete_functions.extend(
      [concrete_function.name for concrete_function in all_concrete_functions])
  return proto