# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports a SavedModel from a Trackable Python object."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import object_identity
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

# Dtypes of captured eager tensors which `_SaveableView.map_resources` does not
# copy into the exported Graph as constants.
_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))


# A container for an EagerTensor constant which has been copied to the exported
# Graph.
_CapturedConstant = collections.namedtuple(
    "_CapturedConstant", ["eager_tensor", "graph_tensor"])


class _AugmentedGraphView(graph_view.ObjectGraphView):
  """An extendable graph which also tracks functions attached to objects.

  Extensions through `add_object` appear in the object graph and any checkpoints
  generated from it, even if they are not dependencies of the node they were
  attached to in the saving program. For example a `.signatures` attribute is
  added to exported SavedModel root objects without modifying the root object
  itself.

  Also tracks functions attached to objects in the graph, through the caching
  `list_functions` method. Enumerating functions only through this method
  ensures that we get a consistent view of functions, even if object attributes
  create new functions every time they are accessed.
  """

  def __init__(self, root):
    super(_AugmentedGraphView, self).__init__(root)
    # Object -> (name -> dep)
    self._extra_dependencies = object_identity.ObjectIdentityDictionary()
    # Object -> cached result of `obj._list_functions_for_serialization()`.
    self._functions = object_identity.ObjectIdentityDictionary()

  def add_object(self, parent_node, name_in_parent, subgraph_root):
    """Attach an object to `parent_node`, overriding any existing dependency."""
    self._extra_dependencies.setdefault(
        parent_node, {})[name_in_parent] = subgraph_root

  def list_dependencies(self, obj):
    """Overrides a parent method to include `add_object` objects.

    A dependency added through `add_object` replaces a same-named dependency
    reported by the parent class; added dependencies with new names are yielded
    after all of the parent's dependencies.
    """
    extra_dependencies = self._extra_dependencies.get(obj, {})
    used_names = set()
    for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
      used_names.add(name)
      if name in extra_dependencies:
        yield base.TrackableReference(name, extra_dependencies[name])
      else:
        yield base.TrackableReference(name, dep)
    for name, dep in extra_dependencies.items():
      if name in used_names:
        continue
      yield base.TrackableReference(name, dep)

  def list_functions(self, obj):
    """Returns the functions attached to `obj`, caching the first listing.

    The first call for a given object queries
    `obj._list_functions_for_serialization()`; later calls return that same
    cached result, so functions created on attribute access are enumerated once.
    """
    obj_functions = self._functions.get(obj, None)
    if obj_functions is None:
      obj_functions = obj._list_functions_for_serialization()  # pylint: disable=protected-access
      self._functions[obj] = obj_functions
    return obj_functions


class _SaveableView(object):
  """Provides a frozen view over a trackable root.

  This class helps create a single stable view over an object to save. The
  saving code should access properties and functions via this class and not via
  the original object, as there are cases where an object constructs its
  trackable attributes and functions dynamically per call and will yield
  different objects if invoked more than once.

  Changes to the graph, for example adding objects, must happen in
  `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is
  constructed. Changes after the `_SaveableView` has been constructed will be
  ignored.
  """

  def __init__(self, checkpoint_view):
    self.checkpoint_view = checkpoint_view
    trackable_objects, node_ids, slot_variables = (
        self.checkpoint_view.objects_ids_and_slot_variables())
    self.nodes = trackable_objects
    self.node_ids = node_ids
    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
    self.slot_variables = slot_variables
    self.concrete_functions = []

    # Also add `Function`s as nodes.
    nodes_without_functions = list(self.nodes)
    seen_function_names = set()
    for node in nodes_without_functions:
      for function in checkpoint_view.list_functions(node).values():
        if function not in self.node_ids:
          self.node_ids[function] = len(self.nodes)
          self.nodes.append(function)
        if isinstance(function, def_function.Function):
          # Force listing the concrete functions for the side effects:
          #  - populate the cache for functions that have an input_signature
          #  and have not been called.
          #  - force side effects of creation of concrete functions, e.g. create
          #  variables on first run.
          concrete_functions = (
              function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
        else:
          concrete_functions = [function]
        for concrete_function in concrete_functions:
          # Deduplicate by name: the same concrete function may be reachable
          # from several objects.
          if concrete_function.name not in seen_function_names:
            seen_function_names.add(concrete_function.name)
            self.concrete_functions.append(concrete_function)

  @property
  def root(self):
    return self.nodes[0]

  def fill_object_graph_proto(self, proto):
    """Populate the nodes, children and slot_variables of a SavedObjectGraph."""
    for node_id, node in enumerate(self.nodes):
      assert self.node_ids[node] == node_id
      object_proto = proto.nodes.add()
      object_proto.slot_variables.extend(self.slot_variables.get(node, ()))
      # Function and constant nodes have no children of their own.
      if isinstance(node, (def_function.Function, defun.ConcreteFunction,
                           _CapturedConstant)):
        continue
      for child in self.checkpoint_view.list_dependencies(node):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[child.ref]
        child_proto.local_name = child.name
      for local_name, ref_function in (
          self.checkpoint_view.list_functions(node).items()):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[ref_function]
        child_proto.local_name = local_name

  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas the
    objects in `self.nodes` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Eager tensors captured by `self.concrete_functions` which are not already
    mapped, and whose dtype is not in `_UNCOPIABLE_DTYPES`, are copied into the
    current graph as constants and appended to `self.nodes` as
    `_CapturedConstant` nodes.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from objects in `self.nodes` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `self.nodes` (and from eager constants captured by functions) to
          newly created tensors in the current graph.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from `self.nodes`.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        new_resource = obj._create_resource()  # pylint: disable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (isinstance(capture, ops.EagerTensor)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          copied_tensor = constant_op.constant(capture.numpy())
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info


def _tensor_dict_to_tensorinfo(tensor_dict):
  """Applies `utils_impl.build_tensor_info_internal` to each value, same keys."""
  return {key: utils_impl.build_tensor_info_internal(value)
          for key, value in tensor_dict.items()}


def _map_captures_to_created_tensors(
    original_captures, resource_map):
  """Maps eager tensors captured by a function to Graph resources for export.

  Args:
    original_captures: A dictionary mapping from tensors captured by the
      function to interior placeholders for those tensors (inside the function
      body).
    resource_map: A dictionary mapping from resource tensors owned by the eager
      context to resource tensors in the exported graph.

  Returns:
    A list of stand-in tensors which belong to the exported graph, corresponding
    to the function's captures.

  Raises:
    AssertionError: If the function references a resource which is not part of
      `resource_map`.
  """
  export_captures = []
  for exterior, interior in original_captures.items():
    mapped_resource = resource_map.get(exterior, None)
    if mapped_resource is None:
      # Note the trailing space after "{}.": adjacent string literals are
      # concatenated verbatim, so the sentences would otherwise run together.
      raise AssertionError(
          ("Tried to export a function which references untracked object {}. "
           "TensorFlow objects (e.g. tf.Variable) captured by functions must "
           "be tracked by assigning them to an attribute of a tracked object "
           "or assigned to an attribute of the main object directly.")
          .format(interior))
    export_captures.append(mapped_resource)
  return export_captures


def _map_function_arguments_to_created_inputs(
    function_arguments, signature_key, function_name):
  """Creates exterior placeholders in the exported graph for function arguments.

  Functions have two types of inputs: tensors captured from the outside (eager)
  context, and arguments to the function which we expect to receive from the
  user at each call. `_map_captures_to_created_tensors` replaces
  captured tensors with stand-ins (typically these are resource dtype tensors
  associated with variables). `_map_function_arguments_to_created_inputs` runs
  over every argument, creating a new placeholder for each which will belong to
  the exported graph rather than the function body.

  Args:
    function_arguments: A list of argument placeholders in the function body.
      Captured inputs must not be included.
    signature_key: The name of the signature being exported, used in error
      messages and as a prefix for the new placeholders' names.
    function_name: The name of the function, for error messages.

  Returns:
    A tuple of (mapped_inputs, exterior_argument_placeholders)
      mapped_inputs: A list with entries corresponding to `function_arguments`,
        containing the newly created placeholders in the exported graph, in the
        same order. Captured inputs are not included; callers append mapped
        captures separately (see `_call_function_with_mapped_captures`).
      exterior_argument_placeholders: A dictionary mapping from argument names
        to placeholders in the exported graph, containing the explicit arguments
        to the function which a user is expected to provide.

  Raises:
    ValueError: If argument names are not unique.
  """
  # `exterior_argument_placeholders` holds placeholders which are outside the
  # function body, directly contained in a MetaGraph of the SavedModel. The
  # function body itself contains nearly identical placeholders used when
  # running the function, but these exterior placeholders allow Session-based
  # APIs to call the function using feeds and fetches which name Tensors in the
  # MetaGraph.
  exterior_argument_placeholders = {}
  mapped_inputs = []
  for placeholder in function_arguments:
    # Every entry is an explicit argument: callers strip captured inputs
    # before calling this function (see `_generate_signatures`).
    user_input_name = compat.as_str_any(
        placeholder.op.get_attr("_user_specified_name"))
    # If the internal placeholders for a function have names which were
    # uniquified by TensorFlow, then a single user-specified argument name
    # must refer to multiple Tensors. The resulting signatures would be
    # confusing to call. Instead, we throw an exception telling the user to
    # specify explicit names.
    if user_input_name != placeholder.op.name:
      # This should be unreachable, since concrete functions may not be
      # generated with non-unique argument names.
      raise ValueError(
          ("Got non-flat/non-unique argument names for SavedModel "
           "signature '{}': more than one argument to '{}' was named '{}'. "
           "Signatures have one Tensor per named input, so to have "
           "predictable names Python functions used to generate these "
           "signatures should avoid *args and Tensors in nested "
           "structures unless unique names are specified for each. Use "
           "tf.TensorSpec(..., name=...) to provide a name for a Tensor "
           "input.")
          .format(signature_key, compat.as_str_any(function_name),
                  user_input_name))
    arg_placeholder = array_ops.placeholder(
        shape=placeholder.shape,
        dtype=placeholder.dtype,
        name="{}_{}".format(signature_key, user_input_name))
    exterior_argument_placeholders[user_input_name] = arg_placeholder
    mapped_inputs.append(arg_placeholder)
  return mapped_inputs, exterior_argument_placeholders


def _call_function_with_mapped_captures(function, args, resource_map):
  """Calls `function` in the exported graph, using mapped resource captures."""
  export_captures = _map_captures_to_created_tensors(
      function.graph.captures, resource_map)
  # Explicit arguments come first, followed by the mapped captures.
  mapped_inputs = args + export_captures
  # Calls the function quite directly, since we have new captured resource
  # tensors we need to feed in which weren't part of the original function
  # definition.
  # pylint: disable=protected-access
  outputs = function._build_call_outputs(
      function._inference_function.call(context.context(), mapped_inputs))
  return outputs


def _generate_signatures(signature_functions, resource_map):
  """Validates and calls `signature_functions` in the default graph.

  Args:
    signature_functions: A dictionary mapping string keys to concrete TensorFlow
      functions (e.g. from `signature_serialization.canonicalize_signatures`)
      which will be used to generate SignatureDefs.
    resource_map: A dictionary mapping from resource tensors in the eager
      context to resource tensors in the Graph being exported. This dictionary
      is used to re-bind resources captured by functions to tensors which will
      exist in the SavedModel.

  Returns:
    Each function in the `signature_functions` dictionary is called with
    placeholder Tensors, generating a function call operation and output
    Tensors. The placeholder Tensors, the function call operation, and the
    output Tensors from the function call are part of the default Graph.

    This function then returns a dictionary with the same structure as
    `signature_functions`, with the concrete functions replaced by SignatureDefs
    implicitly containing information about how to call each function from a
    TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference
    the generated placeholders and Tensor outputs by name.

    The caller is expected to include the Graph that was the default while
    calling this function as a MetaGraph in a SavedModel, including the
    returned SignatureDefs as part of that MetaGraph.
  """
  signatures = {}
  for signature_key, function in sorted(signature_functions.items()):
    # Captured inputs are the trailing entries of `function.graph.inputs`;
    # strip them so only user-provided arguments get exterior placeholders.
    if function.graph.captures:
      argument_inputs = function.graph.inputs[:-len(function.graph.captures)]
    else:
      argument_inputs = function.graph.inputs
    mapped_inputs, exterior_argument_placeholders = (
        _map_function_arguments_to_created_inputs(
            argument_inputs, signature_key, function.name))
    outputs = _call_function_with_mapped_captures(
        function, mapped_inputs, resource_map)
    signatures[signature_key] = signature_def_utils.build_signature_def(
        _tensor_dict_to_tensorinfo(exterior_argument_placeholders),
        _tensor_dict_to_tensorinfo(outputs),
        method_name=signature_constants.PREDICT_METHOD_NAME)
  return signatures


def _trace_resource_initializers(accessible_objects):
  """Create concrete functions from `TrackableResource` objects."""
  resource_initializers = []

  def _wrap_initializer(obj):
    obj._initialize()  # pylint: disable=protected-access
    return constant_op.constant(1.)  # Dummy control output

  def _wrap_obj_initializer(obj):
    # Binds `obj` per call so each traced lambda initializes its own object
    # rather than the loop variable's final value.
    return lambda: _wrap_initializer(obj)

  for obj in accessible_objects:
    if isinstance(obj, tracking.TrackableResource):
      resource_initializers.append(def_function.function(
          _wrap_obj_initializer(obj),
          # All inputs are captures.
          input_signature=[]).get_concrete_function())
  return resource_initializers


_AssetInfo = collections.namedtuple(
    "_AssetInfo", [
        # List of AssetFileDef protocol buffers
        "asset_defs",
        # Map from asset variable resource Tensors to their init ops
        "asset_initializers_by_resource",
        # Map from base asset filenames to full paths
        "asset_filename_map",
        # Map from TrackableAsset to index of corresponding AssetFileDef
        "asset_index"])


def _process_asset(trackable_asset, asset_info, resource_map):
  """Add `trackable_asset` to `asset_info` and `resource_map`."""
  original_variable = trackable_asset.asset_path
  with context.eager_mode():
    original_path = original_variable.numpy()
  path = builder_impl.get_asset_filename_to_add(
      asset_filepath=original_path,
      asset_filename_map=asset_info.asset_filename_map)
  # TODO(andresp): Instead of mapping 1-1 between trackable asset
  # and asset in the graph def consider deduping the assets that
  # point to the same file.
  asset_path_initializer = array_ops.placeholder(
      shape=original_variable.shape,
      dtype=dtypes.string,
      name="asset_path_initializer")
  asset_variable = resource_variable_ops.ResourceVariable(
      asset_path_initializer)
  asset_info.asset_filename_map[path] = original_path
  asset_def = meta_graph_pb2.AssetFileDef()
  asset_def.filename = path
  asset_def.tensor_info.name = asset_path_initializer.name
  asset_info.asset_defs.append(asset_def)
  asset_info.asset_initializers_by_resource[original_variable] = (
      asset_variable.initializer)
  asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
  resource_map[original_variable] = asset_variable


def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.

  Returns:
    A tuple of (asset_info, exported_graph):
      asset_info: An _AssetInfo, which contains information to help creating
        the SavedModel.
      exported_graph: The `ops.Graph` whose GraphDef was written into
        `meta_graph_def`.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = saveable_view.nodes
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for resource_initializer_function in resource_initializer_functions:
      # Resource initializers which read an asset path must run after that
      # asset's variable has been initialized.
      asset_dependencies = []
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(
                resource_initializer_function, [], resource_map))
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(
            init_op, constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  saver = functional_saver.Saver(
      saveable_view.checkpoint_view.frozen_saveable_objects(
          object_map=object_map, to_graph=exported_graph))

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph


def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Builds a SavedObjectGraph proto describing `saveable_view`.

  Args:
    saveable_view: The _SaveableView being exported.
    asset_file_def_index: A dictionary mapping TrackableAsset objects to the
      index of their AssetFileDef (e.g. `_AssetInfo.asset_index`).

  Returns:
    A populated `saved_object_graph_pb2.SavedObjectGraph`.
  """
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  coder = nested_structure_coder.StructureCoder()
  for concrete_function in saveable_view.concrete_functions:
    serialized = function_serialization.serialize_concrete_function(
        concrete_function, saveable_view.captured_tensor_node_ids, coder)
    if serialized is not None:
      proto.concrete_functions[concrete_function.name].CopyFrom(
          serialized)

  # `fill_object_graph_proto` added one proto node per entry of
  # `saveable_view.nodes`, in the same order.
  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index)
  return proto


def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto.

  Args:
    obj: The node from `_SaveableView.nodes` to serialize.
    proto: The SavedObject proto to fill in.
    asset_file_def_index: A dictionary mapping TrackableAsset objects to the
      index of their AssetFileDef.
  """
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.TrackableResource):
    proto.resource.SetInParent()
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)


@tf_export("saved_model.save",
           v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.train.Checkpoint):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This behavior
  is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
  in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example that
  exporting a model which runs on a GPU and serving it on a CPU will generally
  work, with some exceptions. `tf.device` annotations inside the body of the
  function will be hard-coded in the exported model; this type of annotation is
  discouraged. Device-specific operations, e.g. with "cuDNN" in the name or with
  device-specific layouts, may cause issues. Currently a `DistributionStrategy`
  is another exception: active distribution strategies will cause device
  placements to be hard-coded in a function. Exporting a single-device
  computation and importing under a `DistributionStrategy` is not currently
  supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  The current implementation of `tf.saved_model.save` targets serving use-cases,
  but omits information which will be necessary for the planned future
  implementation of `tf.saved_model.load`. Exported models using the current
  `save` implementation, and other existing SavedModels, will not be compatible
  with `tf.saved_model.load` when it is implemented. Further, `save` will in the
  future attempt to export `@tf.function`-decorated methods which it does not
  currently inspect, so some objects which are exportable today will raise
  exceptions on export in the future (e.g. due to complex/non-serializable
  default arguments). Such backwards-incompatible API changes are expected only
  prior to the TensorFlow 2.0 release.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not supported when graph building. From TensorFlow 1.x,
  `tf.enable_eager_execution()` must run first. May not be called from within a
  function body.
  @end_compatibility
  """
  if not context.executing_eagerly():
    with ops.init_scope():
      if context.executing_eagerly():
        raise AssertionError(
            "tf.saved_model.save is not supported inside a traced "
            "@tf.function. Move the call to the outer eagerly-executed "
            "context.")
      else:
        raise AssertionError(
            "tf.saved_model.save is not supported when graph building. "
            "tf.enable_eager_execution() must run first when calling it from "
            "TensorFlow 1.x.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))

  checkpoint_graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  signatures = signature_serialization.canonicalize_signatures(signatures)
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
788 _ = _SaveableView(checkpoint_graph_view) 789 saveable_view = _SaveableView(checkpoint_graph_view) 790 791 # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x 792 # compatible (no sessions) and share it with this export API rather than 793 # making a SavedModel proto and writing it directly. 794 saved_model = saved_model_pb2.SavedModel() 795 meta_graph_def = saved_model.meta_graphs.add() 796 object_saver = util.TrackableSaver(checkpoint_graph_view) 797 asset_info, exported_graph = _fill_meta_graph_def( 798 meta_graph_def, saveable_view, signatures) 799 saved_model.saved_model_schema_version = ( 800 constants.SAVED_MODEL_SCHEMA_VERSION) 801 # So far we've just been generating protocol buffers with no I/O. Now we write 802 # the checkpoint, copy assets into the assets directory, and write out the 803 # SavedModel proto itself. 804 utils_impl.get_or_create_variables_dir(export_dir) 805 object_saver.save(utils_impl.get_variables_path(export_dir)) 806 builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map, 807 export_dir) 808 path = os.path.join( 809 compat.as_bytes(export_dir), 810 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) 811 object_graph_proto = _serialize_object_graph( 812 saveable_view, asset_info.asset_index) 813 meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) 814 file_io.write_string_to_file(path, saved_model.SerializeToString()) 815 # Clean reference cycles so repeated export()s don't make work for the garbage 816 # collector. Before this point we need to keep references to captured 817 # constants in the saved graph. 818 ops.dismantle_graph(exported_graph) 819