1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Exports a SavedModel from a Trackable Python object.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import functools 23import gc 24import os 25 26from absl import logging 27from tensorflow.core.framework import versions_pb2 28from tensorflow.core.protobuf import meta_graph_pb2 29from tensorflow.core.protobuf import saved_model_pb2 30from tensorflow.core.protobuf import saved_object_graph_pb2 31from tensorflow.python.eager import context 32from tensorflow.python.eager import def_function 33from tensorflow.python.eager import function as defun 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import error_interpolation 37from tensorflow.python.framework import errors 38from tensorflow.python.framework import meta_graph 39from tensorflow.python.framework import ops 40from tensorflow.python.framework import tensor_util 41from tensorflow.python.framework import versions 42from tensorflow.python.lib.io import file_io 43from tensorflow.python.ops import array_ops 44from tensorflow.python.ops import control_flow_ops 45from tensorflow.python.ops import resource_variable_ops 46from tensorflow.python.platform import tf_logging 47from tensorflow.python.saved_model import builder_impl 48from tensorflow.python.saved_model import constants 49from tensorflow.python.saved_model import function_serialization 50from tensorflow.python.saved_model import nested_structure_coder 51from tensorflow.python.saved_model import revived_types 52from tensorflow.python.saved_model import save_context 53from tensorflow.python.saved_model import save_options 54from tensorflow.python.saved_model import signature_constants 55from tensorflow.python.saved_model import signature_def_utils 56from tensorflow.python.saved_model import signature_serialization 57from tensorflow.python.saved_model import tag_constants 58from tensorflow.python.saved_model import utils_impl 59from tensorflow.python.training.saving import checkpoint_options 60from tensorflow.python.training.saving import functional_saver 61from tensorflow.python.training.saving import saveable_object_util 62from tensorflow.python.training.tracking import base 63from tensorflow.python.training.tracking import graph_view 64from tensorflow.python.training.tracking import tracking 65from tensorflow.python.training.tracking import util 66from tensorflow.python.util import compat 67from tensorflow.python.util import object_identity 68from tensorflow.python.util.tf_export import tf_export 69 70_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant)) 71 72# A container for an EagerTensor constant which has been copied to the exported 73# Graph. 74_CapturedConstant = collections.namedtuple("_CapturedConstant", 75 ["eager_tensor", "graph_tensor"]) 76 77# Number of untraced functions to display to user in warning message. 78_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5 79 80 81class _AugmentedGraphView(graph_view.ObjectGraphView): 82 """An extendable graph which also tracks functions attached to objects. 83 84 Extensions through `add_object` appear in the object graph and any checkpoints 85 generated from it, even if they are not dependencies of the node they were 86 attached to in the saving program. For example a `.signatures` attribute is 87 added to exported SavedModel root objects without modifying the root object 88 itself. 89 90 Also tracks functions attached to objects in the graph, through the caching 91 `list_functions` method. Enumerating functions only through this method 92 ensures that we get a consistent view of functions, even if object attributes 93 create new functions every time they are accessed. 94 """ 95 96 def __init__(self, root): 97 if (not context.executing_eagerly() and not ops.inside_function()): 98 saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary() 99 else: 100 saveables_cache = None 101 super(_AugmentedGraphView, self).__init__(root, saveables_cache) 102 # Object -> (name -> dep) 103 self._extra_dependencies = object_identity.ObjectIdentityDictionary() 104 self._functions = object_identity.ObjectIdentityDictionary() 105 # Cache shared between objects in the same object graph. This is passed to 106 # each trackable object's `_list_extra_dependencies_for_serialization` and 107 # `_list_functions_for_serialization` function. 108 self._serialization_cache = object_identity.ObjectIdentityDictionary() 109 110 def add_object(self, parent_node, name_in_parent, subgraph_root): 111 """Attach an object to `parent_node`, overriding any existing dependency.""" 112 self._extra_dependencies.setdefault(parent_node, 113 {})[name_in_parent] = subgraph_root 114 115 def list_dependencies(self, obj): 116 """Overrides a parent method to include `add_object` objects.""" 117 extra_dependencies = self.list_extra_dependencies(obj) 118 extra_dependencies.update(self._extra_dependencies.get(obj, {})) 119 120 used_names = set() 121 for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj): 122 used_names.add(name) 123 if name in extra_dependencies: 124 # Extra dependencies (except for `.signatures`, which is always added 125 # when saving) should not have naming conflicts with dependencies 126 # defined by the user. 127 if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME: 128 raise ValueError( 129 "Error when exporting object {} of with identifier={}. The object" 130 " has an attribute named {}, which is reserved. List of all " 131 "reserved attributes: {}".format( 132 obj, 133 obj._object_identifier, # pylint: disable=protected-access 134 name, 135 extra_dependencies.keys())) 136 yield base.TrackableReference(name, extra_dependencies[name]) 137 else: 138 yield base.TrackableReference(name, dep) 139 for name, dep in extra_dependencies.items(): 140 if name in used_names: 141 continue 142 yield base.TrackableReference(name, dep) 143 144 def list_extra_dependencies(self, obj): 145 return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access 146 self._serialization_cache) 147 148 def list_functions(self, obj): 149 obj_functions = self._functions.get(obj, None) 150 if obj_functions is None: 151 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access 152 self._serialization_cache) 153 self._functions[obj] = obj_functions 154 return obj_functions 155 156 157class _SaveableView(object): 158 """Provides a frozen view over a trackable root. 159 160 This class helps to create a single stable view over an object to save. The 161 saving code should access properties and functions via this class and not via 162 the original object as there are cases where an object construct their 163 trackable attributes and functions dynamically per call and will yield 164 different objects if invoked more than once. 165 166 Changes to the graph, for example adding objects, must happen in 167 `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is 168 constructed. Changes after the `_SaveableView` has been constructed will be 169 ignored. 170 """ 171 172 def __init__(self, checkpoint_view, options, wrapped_functions=None): 173 """Initializes a SaveableView. 174 175 Args: 176 checkpoint_view: A GraphView object. 177 options: A SaveOptions instance. 178 wrapped_functions: Dictionary that maps concrete functions to functions 179 that do not capture cached variable values. 180 """ 181 182 self.checkpoint_view = checkpoint_view 183 self._options = options 184 # Maps functions -> wrapped functions that capture variables 185 self._wrapped_functions = wrapped_functions or {} 186 # Run through the nodes in the object graph first for side effects of 187 # creating variables. 188 self._trace_all_concrete_functions() 189 190 (self._trackable_objects, self.node_paths, self._node_ids, 191 self._slot_variables) = ( 192 self.checkpoint_view.objects_ids_and_slot_variables_and_paths()) 193 self._initialize_nodes_and_concrete_functions() 194 195 # Maps names of concrete functions in the object to names of wrapped 196 # functions. When writing the SavedFunction protos, the names of the 197 # wrapped functions should be used in place of the original functions. 198 self.function_name_map = { 199 compat.as_text(original.name): compat.as_text(wrapped.name) 200 for original, wrapped in self._wrapped_functions.items()} 201 self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() 202 203 def _initialize_nodes_and_concrete_functions(self): 204 """Creates graph with nodes for trackable objects and functions. 205 206 Adds functions for each trackable object to `self.nodes` and associated 207 concrete functions to `self.concrete_functions` for serialization. Also adds 208 the object's save and restore functions for loading values from checkpoint. 209 """ 210 self.nodes = list(self._trackable_objects) 211 self.concrete_functions = [] 212 self._seen_function_names = set() 213 self._untraced_functions = [] 214 # Maps node -> local name -> (save function, restore function) 215 self._saveable_objects_map = object_identity.ObjectIdentityDictionary() 216 217 for obj in self._trackable_objects: 218 for function in self.checkpoint_view.list_functions(obj).values(): 219 self._add_function_to_graph(function) 220 # Resource (and TPU/Mirrored) variables are automatically revived with 221 # their saveables defined, so there is no need to trace the save 222 # and restore functions. 223 if resource_variable_ops.is_resource_variable(obj): 224 continue 225 # Trace object save and restore functions to populate `saveables_map` 226 # field in the SavedModel proto. 227 saveable_map = saveable_object_util.trace_save_restore_functions(obj) 228 if saveable_map: 229 for save_fn, restore_fn in saveable_map.values(): 230 self._add_function_to_graph(save_fn) 231 self._add_function_to_graph(restore_fn) 232 self._saveable_objects_map[obj] = saveable_map 233 234 if self._untraced_functions: 235 logging.warning( 236 "Found untraced functions such as %s while saving (showing %d of %d)." 237 " These functions will not be directly callable after loading.", 238 ", ".join(self._untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]), 239 min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self._untraced_functions)), 240 len(self._untraced_functions)) 241 242 def _add_function_to_graph(self, function): 243 """Adds function to serialize to graph.""" 244 # Updates self.nodes, self._node_ids, self.concrete_functions, 245 # and self._untraced_functions. 246 if function not in self._node_ids: 247 self._node_ids[function] = len(self.nodes) 248 # Add the function to nodes as well. 249 self.nodes.append(function) 250 if isinstance(function, def_function.Function): 251 concrete_functions = ( 252 function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access 253 else: 254 concrete_functions = [function] 255 if not concrete_functions: 256 self._untraced_functions.append(function._name) # pylint: disable=protected-access 257 for concrete_function in concrete_functions: 258 if concrete_function.name not in self._seen_function_names: 259 self.concrete_functions.append(concrete_function) 260 self._seen_function_names.add(concrete_function.name) 261 262 def _trace_all_concrete_functions(self): 263 """Trace concrete functions to force side-effects. 264 265 Lists the concrete functions in order to: 266 - populate the cache for functions that have an input_signature 267 and have not been called 268 - force side effects of creation of concrete functions, e.g. create 269 variables on first run. 270 """ 271 for obj in self.checkpoint_view.list_objects(): 272 for function in self.checkpoint_view.list_functions(obj).values(): 273 if isinstance(function, def_function.Function): 274 function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access 275 276 @property 277 def root(self): 278 return self.nodes[0] 279 280 def fill_object_graph_proto(self, proto): 281 """Populate the nodes, children and slot_variables of a SavedObjectGraph.""" 282 for node_id, node in enumerate(self.nodes): 283 assert self._node_ids[node] == node_id 284 object_proto = proto.nodes.add() 285 object_proto.slot_variables.extend(self._slot_variables.get(node, ())) 286 if isinstance( 287 node, 288 (def_function.Function, defun.ConcreteFunction, _CapturedConstant)): 289 continue 290 for child in self.checkpoint_view.list_dependencies(node): 291 child_proto = object_proto.children.add() 292 child_proto.node_id = self._node_ids[child.ref] 293 child_proto.local_name = child.name 294 for local_name, ref_function in ( 295 self.checkpoint_view.list_functions(node).items()): 296 child_proto = object_proto.children.add() 297 child_proto.node_id = self._node_ids[ref_function] 298 child_proto.local_name = local_name 299 300 if node not in self._saveable_objects_map: 301 continue 302 303 for local_name, (save_fn, restore_fn) in ( 304 self._saveable_objects_map[node].items()): 305 saveable_object_proto = object_proto.saveable_objects[local_name] 306 saveable_object_proto.save_function = self._node_ids[save_fn] 307 saveable_object_proto.restore_function = self._node_ids[restore_fn] 308 309 def map_resources(self): 310 """Makes new resource handle ops corresponding to existing resource tensors. 311 312 Creates resource handle ops in the current default graph, whereas 313 `accessible_objects` will be from an eager context. Resource mapping adds 314 resource handle ops to the main GraphDef of a SavedModel, which allows the 315 C++ loader API to interact with resources. 316 317 Returns: 318 A tuple of (object_map, resource_map, asset_info): 319 object_map: A dictionary mapping from object in `accessible_objects` to 320 replacement objects created to hold the new resource tensors. 321 resource_map: A dictionary mapping from resource tensors extracted from 322 `accessible_objects` to newly created resource tensors. 323 asset_info: An _AssetInfo tuple describing external assets referenced 324 from accessible_objects. 325 """ 326 # Only makes sense when adding to the export Graph 327 assert not context.executing_eagerly() 328 # TODO(allenl): Handle MirroredVariables and other types of variables which 329 # may need special casing. 330 object_map = object_identity.ObjectIdentityDictionary() 331 resource_map = {} 332 asset_info = _AssetInfo( 333 asset_defs=[], 334 asset_initializers_by_resource={}, 335 asset_filename_map={}, 336 asset_index={}) 337 338 for node_id, obj in enumerate(self.nodes): 339 if isinstance(obj, tracking.Asset): 340 _process_asset(obj, asset_info, resource_map) 341 self.captured_tensor_node_ids[obj.asset_path] = node_id 342 elif isinstance(obj, base.Trackable): 343 node_object_map, node_resource_map = obj._map_resources(self._options) # pylint: disable=protected-access 344 for capturable in node_resource_map.keys(): 345 self.captured_tensor_node_ids[capturable] = node_id 346 object_map.update(node_object_map) 347 resource_map.update(node_resource_map) 348 349 # Note: some concrete functions can have been realized when tracing other 350 # functions, and might closure-capture tensors from their parent functions. 351 # This is normal, but it means those concrete functions can't be serialized 352 # as their own independent endpoints, so we filter them out here. 353 bad_functions = [] 354 for concrete_function in self.concrete_functions: 355 if not concrete_function.graph.saveable: 356 raise ValueError( 357 ("Unable to save function {name} for the following reason(s):\n" + 358 "\n".join(concrete_function.graph.saving_errors)).format( 359 name=concrete_function.name)) 360 for capture in concrete_function.captured_inputs: 361 if (tensor_util.is_tf_type(capture) and 362 capture.dtype not in _UNCOPIABLE_DTYPES and 363 capture not in self.captured_tensor_node_ids): 364 if hasattr(capture, "_cached_variable"): 365 if concrete_function not in self._wrapped_functions: 366 wrapped = self._wrapped_functions[concrete_function] = ( 367 function_serialization.wrap_cached_variables( 368 concrete_function)) 369 self.function_name_map[compat.as_text(concrete_function.name)] = ( 370 compat.as_text(wrapped.name)) 371 continue 372 capture_constant_value = tensor_util.constant_value(capture) 373 if capture_constant_value is None: 374 bad_functions.append(concrete_function) 375 continue 376 copied_tensor = constant_op.constant(capture_constant_value) 377 node_id = len(self.nodes) 378 node = _CapturedConstant( 379 eager_tensor=capture, graph_tensor=copied_tensor) 380 self.nodes.append(node) 381 self._node_ids[capture] = node_id 382 self._node_ids[node] = node_id 383 self.captured_tensor_node_ids[capture] = node_id 384 resource_map[capture] = copied_tensor 385 386 self.concrete_functions = [ 387 self._wrapped_functions.get(x, x) for x in self.concrete_functions 388 if x not in bad_functions 389 ] 390 return object_map, resource_map, asset_info 391 392 393def _tensor_dict_to_tensorinfo(tensor_dict): 394 return { 395 key: utils_impl.build_tensor_info_internal(value) 396 for key, value in tensor_dict.items() 397 } 398 399 400def _map_captures_to_created_tensors(original_captures, resource_map): 401 """Maps eager tensors captured by a function to Graph resources for export. 402 403 Args: 404 original_captures: A dictionary mapping from tensors captured by the 405 function to interior placeholders for those tensors (inside the function 406 body). 407 resource_map: A dictionary mapping from resource tensors owned by the eager 408 context to resource tensors in the exported graph. 409 410 Returns: 411 A list of stand-in tensors which belong to the exported graph, corresponding 412 to the function's captures. 413 414 Raises: 415 AssertionError: If the function references a resource which is not part of 416 `resource_map`. 417 """ 418 export_captures = [] 419 for exterior, interior in original_captures: 420 mapped_resource = resource_map.get(exterior, None) 421 if mapped_resource is None: 422 trackable_referrers = [] 423 # Try to figure out where the resource came from by iterating over objects 424 # which reference it. This is slow and doesn't help us figure out how to 425 # match it to other objects when loading the SavedModel as a checkpoint, 426 # so we can't continue saving. But we can at least tell the user what 427 # needs attaching. 428 for primary_referrer in gc.get_referrers(exterior): 429 if isinstance(primary_referrer, base.Trackable): 430 trackable_referrers.append(primary_referrer) 431 for secondary_referrer in gc.get_referrers(primary_referrer): 432 if isinstance(secondary_referrer, base.Trackable): 433 trackable_referrers.append(secondary_referrer) 434 raise AssertionError( 435 ("Tried to export a function which references untracked resource {}. " 436 "TensorFlow objects (e.g. tf.Variable) captured by functions must " 437 "be tracked by assigning them to an attribute of a tracked object " 438 "or assigned to an attribute of the main object directly.\n\n" 439 "Trackable Python objects referring to this tensor " 440 "(from gc.get_referrers, limited to two hops):\n{}" 441 ).format(interior, 442 "\n".join([repr(obj) for obj in trackable_referrers]))) 443 export_captures.append(mapped_resource) 444 return export_captures 445 446 447def _map_function_arguments_to_created_inputs(function_arguments, signature_key, 448 function_name): 449 """Creates exterior placeholders in the exported graph for function arguments. 450 451 Functions have two types of inputs: tensors captured from the outside (eager) 452 context, and arguments to the function which we expect to receive from the 453 user at each call. `_map_captures_to_created_tensors` replaces 454 captured tensors with stand-ins (typically these are resource dtype tensors 455 associated with variables). `_map_function_inputs_to_created_inputs` runs over 456 every argument, creating a new placeholder for each which will belong to the 457 exported graph rather than the function body. 458 459 Args: 460 function_arguments: A list of argument placeholders in the function body. 461 signature_key: The name of the signature being exported, for error messages. 462 function_name: The name of the function, for error messages. 463 464 Returns: 465 A tuple of (mapped_inputs, exterior_placeholders) 466 mapped_inputs: A list with entries corresponding to `function_arguments` 467 containing all of the inputs of the function gathered from the exported 468 graph (both captured resources and arguments). 469 exterior_argument_placeholders: A dictionary mapping from argument names 470 to placeholders in the exported graph, containing the explicit arguments 471 to the function which a user is expected to provide. 472 473 Raises: 474 ValueError: If argument names are not unique. 475 """ 476 # `exterior_argument_placeholders` holds placeholders which are outside the 477 # function body, directly contained in a MetaGraph of the SavedModel. The 478 # function body itself contains nearly identical placeholders used when 479 # running the function, but these exterior placeholders allow Session-based 480 # APIs to call the function using feeds and fetches which name Tensors in the 481 # MetaGraph. 482 exterior_argument_placeholders = {} 483 mapped_inputs = [] 484 for placeholder in function_arguments: 485 # `export_captures` contains an exhaustive set of captures, so if we don't 486 # find the input there then we now know we have an argument. 487 user_input_name = compat.as_str_any( 488 placeholder.op.get_attr("_user_specified_name")) 489 # If the internal placeholders for a function have names which were 490 # uniquified by TensorFlow, then a single user-specified argument name 491 # must refer to multiple Tensors. The resulting signatures would be 492 # confusing to call. Instead, we throw an exception telling the user to 493 # specify explicit names. 494 if user_input_name != placeholder.op.name: 495 # This should be unreachable, since concrete functions may not be 496 # generated with non-unique argument names. 497 raise ValueError( 498 ("Got non-flat/non-unique argument names for SavedModel " 499 "signature '{}': more than one argument to '{}' was named '{}'. " 500 "Signatures have one Tensor per named input, so to have " 501 "predictable names Python functions used to generate these " 502 "signatures should avoid *args and Tensors in nested " 503 "structures unless unique names are specified for each. Use " 504 "tf.TensorSpec(..., name=...) to provide a name for a Tensor " 505 "input.").format(signature_key, compat.as_str_any(function_name), 506 user_input_name)) 507 arg_placeholder = array_ops.placeholder( 508 shape=placeholder.shape, 509 dtype=placeholder.dtype, 510 name="{}_{}".format(signature_key, user_input_name)) 511 exterior_argument_placeholders[user_input_name] = arg_placeholder 512 mapped_inputs.append(arg_placeholder) 513 return mapped_inputs, exterior_argument_placeholders 514 515 516def _call_function_with_mapped_captures(function, args, resource_map): 517 """Calls `function` in the exported graph, using mapped resource captures.""" 518 export_captures = _map_captures_to_created_tensors(function.graph.captures, 519 resource_map) 520 # Calls the function quite directly, since we have new captured resource 521 # tensors we need to feed in which weren't part of the original function 522 # definition. 523 # pylint: disable=protected-access 524 outputs = function._call_flat(args, export_captures) 525 # pylint: enable=protected-access 526 return outputs 527 528 529def _generate_signatures(signature_functions, resource_map): 530 """Validates and calls `signature_functions` in the default graph. 531 532 Args: 533 signature_functions: A dictionary mapping string keys to concrete TensorFlow 534 functions (e.g. from `signature_serialization.canonicalize_signatures`) 535 which will be used to generate SignatureDefs. 536 resource_map: A dictionary mapping from resource tensors in the eager 537 context to resource tensors in the Graph being exported. This dictionary 538 is used to re-bind resources captured by functions to tensors which will 539 exist in the SavedModel. 540 541 Returns: 542 Each function in the `signature_functions` dictionary is called with 543 placeholder Tensors, generating a function call operation and output 544 Tensors. The placeholder Tensors, the function call operation, and the 545 output Tensors from the function call are part of the default Graph. 546 547 This function then returns a dictionary with the same structure as 548 `signature_functions`, with the concrete functions replaced by SignatureDefs 549 implicitly containing information about how to call each function from a 550 TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference 551 the generated placeholders and Tensor outputs by name. 552 553 The caller is expected to include the default Graph set while calling this 554 function as a MetaGraph in a SavedModel, including the returned 555 SignatureDefs as part of that MetaGraph. 556 """ 557 signatures = {} 558 for signature_key, function in sorted(signature_functions.items()): 559 if function.graph.captures: 560 argument_inputs = function.graph.inputs[:-len(function.graph.captures)] 561 else: 562 argument_inputs = function.graph.inputs 563 mapped_inputs, exterior_argument_placeholders = ( 564 _map_function_arguments_to_created_inputs(argument_inputs, 565 signature_key, function.name)) 566 outputs = _call_function_with_mapped_captures( 567 function, mapped_inputs, resource_map) 568 signatures[signature_key] = signature_def_utils.build_signature_def( 569 _tensor_dict_to_tensorinfo(exterior_argument_placeholders), 570 _tensor_dict_to_tensorinfo(outputs), 571 method_name=signature_constants.PREDICT_METHOD_NAME) 572 return signatures 573 574 575def _trace_resource_initializers(accessible_objects): 576 """Create concrete functions from `CapturableResource` objects.""" 577 resource_initializers = [] 578 579 def _wrap_initializer(obj): 580 obj._initialize() # pylint: disable=protected-access 581 return constant_op.constant(1.) # Dummy control output 582 583 def _wrap_obj_initializer(obj): 584 return lambda: _wrap_initializer(obj) 585 586 for obj in accessible_objects: 587 if isinstance(obj, tracking.CapturableResource): 588 resource_initializers.append( 589 def_function.function( 590 _wrap_obj_initializer(obj), 591 # All inputs are captures. 592 input_signature=[]).get_concrete_function()) 593 return resource_initializers 594 595 596_AssetInfo = collections.namedtuple( 597 "_AssetInfo", 598 [ 599 # List of AssetFileDef protocol buffers 600 "asset_defs", 601 # Map from asset variable resource Tensors to their init ops 602 "asset_initializers_by_resource", 603 # Map from base asset filenames to full paths 604 "asset_filename_map", 605 # Map from Asset to index of corresponding AssetFileDef 606 "asset_index" 607 ]) 608 609 610def _process_asset(trackable_asset, asset_info, resource_map): 611 """Add `trackable_asset` to `asset_info` and `resource_map`.""" 612 original_path_tensor = trackable_asset.asset_path 613 original_path = tensor_util.constant_value(original_path_tensor) 614 try: 615 original_path = str(original_path.astype(str)) 616 except AttributeError: 617 # Already a string rather than a numpy array 618 pass 619 path = builder_impl.get_asset_filename_to_add( 620 asset_filepath=original_path, 621 asset_filename_map=asset_info.asset_filename_map) 622 # TODO(andresp): Instead of mapping 1-1 between trackable asset 623 # and asset in the graph def consider deduping the assets that 624 # point to the same file. 625 asset_path_initializer = array_ops.placeholder( 626 shape=original_path_tensor.shape, 627 dtype=dtypes.string, 628 name="asset_path_initializer") 629 asset_variable = resource_variable_ops.ResourceVariable( 630 asset_path_initializer) 631 asset_info.asset_filename_map[path] = original_path 632 asset_def = meta_graph_pb2.AssetFileDef() 633 asset_def.filename = path 634 asset_def.tensor_info.name = asset_path_initializer.name 635 asset_info.asset_defs.append(asset_def) 636 asset_info.asset_initializers_by_resource[original_path_tensor] = ( 637 asset_variable.initializer) 638 asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1 639 resource_map[original_path_tensor] = asset_variable 640 641 642def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, 643 namespace_whitelist): 644 """Generates a MetaGraph which calls `signature_functions`. 645 646 Args: 647 meta_graph_def: The MetaGraphDef proto to fill. 648 saveable_view: The _SaveableView being exported. 649 signature_functions: A dictionary mapping signature keys to concrete 650 functions containing signatures to add to the MetaGraph. 651 namespace_whitelist: List of strings containing whitelisted op namespaces. 652 653 Returns: 654 A tuple of (_AssetInfo, Graph) containing the captured assets and 655 exported Graph generated from tracing the saveable_view. 656 """ 657 # List objects from the eager context to make sure Optimizers give us the 658 # right Graph-dependent variables. 659 accessible_objects = saveable_view.nodes 660 resource_initializer_functions = _trace_resource_initializers( 661 accessible_objects) 662 exported_graph = ops.Graph() 663 resource_initializer_ops = [] 664 with exported_graph.as_default(): 665 object_map, resource_map, asset_info = saveable_view.map_resources() 666 for resource_initializer_function in resource_initializer_functions: 667 asset_dependencies = [] 668 for capture in resource_initializer_function.graph.external_captures: 669 asset_initializer = asset_info.asset_initializers_by_resource.get( 670 capture, None) 671 if asset_initializer is not None: 672 asset_dependencies.append(asset_initializer) 673 with ops.control_dependencies(asset_dependencies): 674 resource_initializer_ops.append( 675 _call_function_with_mapped_captures(resource_initializer_function, 676 [], resource_map)) 677 resource_initializer_ops.extend( 678 asset_info.asset_initializers_by_resource.values()) 679 with ops.control_dependencies(resource_initializer_ops): 680 init_op = control_flow_ops.no_op() 681 # Add the same op to the main_op collection and to the init_op 682 # signature. The collection is for compatibility with older loader APIs; 683 # only one will be executed. 684 meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append( 685 init_op.name) 686 meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom( 687 signature_def_utils.op_signature_def(init_op, 688 constants.INIT_OP_SIGNATURE_KEY)) 689 690 # Saving an object-based checkpoint again gathers variables. We need to do the 691 # gathering from the eager context so Optimizers save the right set of 692 # variables, but want any operations associated with the save/restore to be in 693 # the exported graph (thus the `to_graph` argument). 694 saver = functional_saver.MultiDeviceSaver( 695 saveable_view.checkpoint_view.frozen_saveable_objects( 696 object_map=object_map, to_graph=exported_graph, 697 call_with_mapped_captures=functools.partial( 698 _call_function_with_mapped_captures, resource_map=resource_map))) 699 700 with exported_graph.as_default(): 701 signatures = _generate_signatures(signature_functions, resource_map) 702 for concrete_function in saveable_view.concrete_functions: 703 concrete_function.add_to_graph() 704 saver_def = saver.to_proto() 705 meta_graph_def.saver_def.CopyFrom(saver_def) 706 graph_def = exported_graph.as_graph_def(add_shapes=True) 707 _verify_ops(graph_def, namespace_whitelist) 708 709 meta_graph_def.graph_def.CopyFrom(graph_def) 710 meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING) 711 meta_graph_def.meta_info_def.tensorflow_version = versions.__version__ 712 meta_graph_def.meta_info_def.tensorflow_git_version = ( 713 versions.__git_version__) 714 # We currently always strip default attributes. 715 meta_graph_def.meta_info_def.stripped_default_attrs = True 716 meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( 717 meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)) 718 meta_graph_def.asset_file_def.extend(asset_info.asset_defs) 719 for signature_key, signature in signatures.items(): 720 meta_graph_def.signature_def[signature_key].CopyFrom(signature) 721 meta_graph.strip_graph_default_valued_attrs(meta_graph_def) 722 return asset_info, exported_graph 723 724 725def _verify_ops(graph_def, namespace_whitelist): 726 """Verifies that all namespaced ops in the graph are whitelisted.""" 727 invalid_ops = [] 728 invalid_namespaces = set() 729 730 all_operations = [] 731 all_operations.extend(meta_graph.ops_used_by_graph_def(graph_def)) 732 733 for op in all_operations: 734 if ">" in op: 735 namespace = op.split(">")[0] 736 if namespace not in namespace_whitelist: 737 invalid_ops.append(op) 738 invalid_namespaces.add(namespace) 739 if invalid_ops: 740 raise ValueError( 741 "Attempted to save ops from non-whitelisted namespaces to SavedModel: " 742 "{}.\nPlease verify that these ops should be saved, since they must be " 743 "available when loading the SavedModel. If loading from Python, you " 744 "must import the library defining these ops. From C++, link the custom " 745 "ops to the serving binary. Once you've confirmed this, please add the " 746 "following namespaces to the `namespace_whitelist` argument in " 747 "tf.saved_model.SaveOptions: {}.".format(invalid_ops, 748 invalid_namespaces)) 749 750 751def _serialize_object_graph(saveable_view, asset_file_def_index): 752 """Save a SavedObjectGraph proto for `root`.""" 753 # SavedObjectGraph is similar to the TrackableObjectGraph proto in the 754 # checkpoint. It will eventually go into the SavedModel. 755 proto = saved_object_graph_pb2.SavedObjectGraph() 756 saveable_view.fill_object_graph_proto(proto) 757 758 coder = nested_structure_coder.StructureCoder() 759 for concrete_function in saveable_view.concrete_functions: 760 name = compat.as_text(concrete_function.name) 761 name = saveable_view.function_name_map.get(name, name) 762 serialized = function_serialization.serialize_concrete_function( 763 concrete_function, saveable_view.captured_tensor_node_ids, coder) 764 if serialized is not None: 765 proto.concrete_functions[name].CopyFrom(serialized) 766 767 saved_object_metadata = False 768 for obj, obj_proto in zip(saveable_view.nodes, proto.nodes): 769 has_saved_object_metadata = _write_object_proto( 770 obj, obj_proto, asset_file_def_index, saveable_view.function_name_map) 771 saved_object_metadata = saved_object_metadata or has_saved_object_metadata 772 return proto, saved_object_metadata 773 774 775def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): 776 """Saves an object into SavedObject proto.""" 777 has_saved_object_metadata = False # The metadata field will be deprecated. 778 if isinstance(obj, tracking.Asset): 779 proto.asset.SetInParent() 780 proto.asset.asset_file_def_index = asset_file_def_index[obj] 781 elif resource_variable_ops.is_resource_variable(obj): 782 proto.variable.SetInParent() 783 if not obj.name.endswith(":0"): 784 raise ValueError("Cowardly refusing to save variable {} because of" 785 " unexpected suffix which won't be restored.".format( 786 obj.name)) 787 proto.variable.name = meta_graph._op_name(obj.name) # pylint: disable=protected-access 788 proto.variable.trainable = obj.trainable 789 proto.variable.dtype = obj.dtype.as_datatype_enum 790 proto.variable.synchronization = obj.synchronization.value 791 proto.variable.aggregation = obj.aggregation.value 792 proto.variable.shape.CopyFrom(obj.shape.as_proto()) 793 options = save_context.get_save_options() 794 if options.experimental_variable_policy._save_variable_devices( # pylint: disable=protected-access 795 ): 796 if hasattr(obj, "device"): 797 proto.variable.device = obj.device 798 elif isinstance(obj, def_function.Function): 799 proto.function.CopyFrom(function_serialization.serialize_function( 800 obj, function_name_map)) 801 elif isinstance(obj, defun.ConcreteFunction): 802 proto.bare_concrete_function.CopyFrom( 803 function_serialization.serialize_bare_concrete_function( 804 obj, function_name_map)) 805 elif isinstance(obj, _CapturedConstant): 806 proto.constant.operation = obj.graph_tensor.op.name 807 elif isinstance(obj, tracking.CapturableResource): 808 proto.resource.device = obj._resource_device # pylint: disable=protected-access 809 else: 810 registered_type_proto = revived_types.serialize(obj) 811 if registered_type_proto is None: 812 # Fallback for types with no matching registration 813 # pylint:disable=protected-access 814 metadata = obj._tracking_metadata 815 if metadata: 816 has_saved_object_metadata = True 817 registered_type_proto = saved_object_graph_pb2.SavedUserObject( 818 identifier=obj._object_identifier, 819 version=versions_pb2.VersionDef( 820 producer=1, min_consumer=1, bad_consumers=[]), 821 metadata=metadata) 822 # pylint:enable=protected-access 823 proto.user_object.CopyFrom(registered_type_proto) 824 825 # Give the object a chance to modify the SavedObject proto. 826 # This is currently used by MirroredVariables to optionally write their 827 # component variables to the proto. 828 # 829 # This is not yet an official Trackable method, the only current use case 830 # being MirroredVariables. See the method implementation there for more 831 # documentation. 832 if hasattr(obj, "_write_object_proto"): 833 obj._write_object_proto(proto, options) # pylint: disable=protected-access 834 return has_saved_object_metadata 835 836 837def _export_debug_info(exported_graph, export_dir): 838 """Exports debug information from graph to file. 839 840 Creates and writes GraphDebugInfo with traces for ops in all functions of the 841 exported_graph. 842 843 Args: 844 exported_graph: A Graph that has been created by tracing a saveable view. 845 export_dir: SavedModel directory in which to write the debug info. 846 """ 847 exported_operations = [] 848 for fn_name in exported_graph._functions: # pylint: disable=protected-access 849 fn = exported_graph._get_function(fn_name) # pylint: disable=protected-access 850 if not isinstance(fn, defun._EagerDefinedFunction): # pylint: disable=protected-access 851 continue 852 853 fn_graph = fn.graph 854 for fn_op in fn_graph.get_operations(): 855 exported_operations.append((fn_name, fn_op)) 856 857 graph_debug_info = error_interpolation.create_graph_debug_info_def( 858 exported_operations) 859 file_io.atomic_write_string_to_file( 860 os.path.join( 861 utils_impl.get_or_create_debug_dir(export_dir), 862 constants.DEBUG_INFO_FILENAME_PB), 863 graph_debug_info.SerializeToString(deterministic=True)) 864 865 866@tf_export( 867 "saved_model.save", 868 v1=["saved_model.save", "saved_model.experimental.save"]) 869def save(obj, export_dir, signatures=None, options=None): 870 # pylint: disable=line-too-long 871 """Exports a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) (and subclasses) `obj` to [SavedModel format](https://www.tensorflow.org/guide/saved_model#the_savedmodel_format_on_disk). 872 873 The `obj` must inherit from the [`Trackable` class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/tracking/base.py#L591). 874 875 Example usage: 876 877 >>> class Adder(tf.Module): 878 ... @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)]) 879 ... def add(self, x): 880 ... return x + x 881 882 >>> model = Adder() 883 >>> tf.saved_model.save(model, '/tmp/adder') 884 885 The resulting SavedModel is then servable with an input named "x", a scalar 886 with dtype float32. 887 888 _Signatures_ 889 890 Signatures define the input and output types for a computation. The optional 891 save `signatures` argument controls which methods in `obj` will be 892 available to programs which consume `SavedModel`s, for example, serving 893 APIs. Python functions may be decorated with 894 `@tf.function(input_signature=...)` and passed as signatures directly, or 895 lazily with a call to `get_concrete_function` on the method decorated with 896 `@tf.function`. 897 898 Example: 899 900 >>> class Adder(tf.Module): 901 ... @tf.function 902 ... def add(self, x): 903 ... return x + x 904 905 >>> model = Adder() 906 >>> tf.saved_model.save( 907 ... model, '/tmp/adder',signatures=model.add.get_concrete_function( 908 ... tf.TensorSpec([], tf.float32))) 909 910 If a `@tf.function` does not have an input signature and 911 `get_concrete_function` is not called on that method, the function will not 912 be directly callable in the restored SavedModel. 913 914 Example: 915 916 >>> class Adder(tf.Module): 917 ... @tf.function 918 ... def add(self, x): 919 ... return x + x 920 921 >>> model = Adder() 922 >>> tf.saved_model.save(model, '/tmp/adder') 923 >>> restored = tf.saved_model.load('/tmp/adder') 924 >>> restored.add(1.) 925 Traceback (most recent call last): 926 ... 927 ValueError: Found zero restored functions for caller function. 928 929 If the `signatures` argument is omitted, `obj` will be searched for 930 `@tf.function`-decorated methods. If exactly one traced `@tf.function` is 931 found, that method will be used as the default signature for the SavedModel. 932 Else, any `@tf.function` attached to `obj` or its dependencies will be 933 exported for use with `tf.saved_model.load`. 934 935 When invoking a signature in an exported SavedModel, `Tensor` arguments are 936 identified by name. These names will come from the Python function's argument 937 names by default. They may be overridden by specifying a `name=...` argument 938 in the corresponding `tf.TensorSpec` object. Explicit naming is required if 939 multiple `Tensor`s are passed through a single argument to the Python 940 function. 941 942 The outputs of functions used as `signatures` must either be flat lists, in 943 which case outputs will be numbered, or a dictionary mapping string keys to 944 `Tensor`, in which case the keys will be used to name outputs. 945 946 Signatures are available in objects returned by `tf.saved_model.load` as a 947 `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save` 948 on an object with a custom `.signatures` attribute will raise an exception. 949 950 _Using `tf.saved_model.save` with Keras models_ 951 952 While Keras has its own [saving and loading API](https://www.tensorflow.org/guide/keras/save_and_serialize), 953 this function can be used to export Keras models. For example, exporting with 954 a signature specified: 955 956 >>> class Adder(tf.keras.Model): 957 ... @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) 958 ... def concat(self, x): 959 ... return x + x 960 961 >>> model = Adder() 962 >>> tf.saved_model.save(model, '/tmp/adder') 963 964 Exporting from a function without a fixed signature: 965 966 >>> class Adder(tf.keras.Model): 967 ... @tf.function 968 ... def concat(self, x): 969 ... return x + x 970 971 >>> model = Adder() 972 >>> tf.saved_model.save( 973 ... model, '/tmp/adder', 974 ... signatures=model.concat.get_concrete_function( 975 ... tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))) 976 977 `tf.keras.Model` instances constructed from inputs and outputs already have a 978 signature and so do not require a `@tf.function` decorator or a `signatures` 979 argument. If neither are specified, the model's forward pass is exported. 980 981 >>> x = tf.keras.layers.Input((4,), name="x") 982 >>> y = tf.keras.layers.Dense(5, name="out")(x) 983 >>> model = tf.keras.Model(x, y) 984 >>> tf.saved_model.save(model, '/tmp/saved_model/') 985 986 The exported SavedModel takes "x" with shape [None, 4] and returns "out" 987 with shape [None, 5] 988 989 _Variables and Checkpoints_ 990 991 Variables must be tracked by assigning them to an attribute of a tracked 992 object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers 993 from `tf.keras.layers`, optimizers from `tf.train`) track their variables 994 automatically. This is the same tracking scheme that `tf.train.Checkpoint` 995 uses, and an exported `Checkpoint` object may be restored as a training 996 checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's 997 "variables/" subdirectory. 998 999 `tf.function` does not hard-code device annotations from outside the function 1000 body, instead of using the calling context's device. This means for example 1001 that exporting a model that runs on a GPU and serving it on a CPU will 1002 generally work, with some exceptions: 1003 1004 * `tf.device` annotations inside the body of the function will be hard-coded 1005 in the exported model; this type of annotation is discouraged. 1006 * Device-specific operations, e.g. with "cuDNN" in the name or with 1007 device-specific layouts, may cause issues. 1008 * For `ConcreteFunctions`, active distribution strategies will cause device 1009 placements to be hard-coded in the function. 1010 1011 SavedModels exported with `tf.saved_model.save` [strip default-valued 1012 attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes) 1013 automatically, which removes one source of incompatibilities when the consumer 1014 of a SavedModel is running an older TensorFlow version than the 1015 producer. There are however other sources of incompatibilities which are not 1016 handled automatically, such as when the exported model contains operations 1017 which the consumer does not have definitions for. 1018 1019 Args: 1020 obj: A trackable object (e.g. tf.Module or tf.train.Checkpoint) to export. 1021 export_dir: A directory in which to write the SavedModel. 1022 signatures: Optional, one of three types: 1023 * a `tf.function` with an input signature specified, which will use the 1024 default serving signature key, 1025 * the result of `f.get_concrete_function` on a `@tf.function`-decorated 1026 function `f`, in which case `f` will be used to generate a signature for 1027 the SavedModel under the default serving signature key, 1028 * a dictionary, which maps signature keys to either `tf.function` 1029 instances with input signatures or concrete functions. Keys of such a 1030 dictionary may be arbitrary strings, but will typically be from the 1031 `tf.saved_model.signature_constants` module. 1032 options: `tf.saved_model.SaveOptions` object for configuring save options. 1033 1034 Raises: 1035 ValueError: If `obj` is not trackable. 1036 1037 @compatibility(eager) 1038 Not well supported when graph building. From TensorFlow 1.x, 1039 `tf.compat.v1.enable_eager_execution()` should run first. Calling 1040 tf.saved_model.save in a loop when graph building from TensorFlow 1.x will 1041 add new save operations to the default graph each iteration. 1042 1043 May not be called from within a function body. 1044 @end_compatibility 1045 """ 1046 # pylint: enable=line-too-long 1047 save_and_return_nodes(obj, export_dir, signatures, options, 1048 raise_metadata_warning=True) 1049 1050 1051def save_and_return_nodes(obj, 1052 export_dir, 1053 signatures=None, 1054 options=None, 1055 raise_metadata_warning=False, 1056 experimental_skip_checkpoint=False): 1057 """Saves a SavedModel while returning all saved nodes and their paths. 1058 1059 Please see `tf.saved_model.save` for details. 1060 1061 Args: 1062 obj: A trackable object to export. 1063 export_dir: A directory in which to write the SavedModel. 1064 signatures: A function or dictionary of functions to save in the SavedModel 1065 as signatures. 1066 options: `tf.saved_model.SaveOptions` object for configuring save options. 1067 raise_metadata_warning: Whether to raise the metadata warning. This arg will 1068 be removed in TF 2.5. 1069 experimental_skip_checkpoint: If set to `True`, the checkpoint will not 1070 be written. 1071 1072 Returns: 1073 A tuple of (a list of saved nodes in the order they are serialized to the 1074 `SavedObjectGraph`, dictionary mapping nodes to one possible path from 1075 the root node to the key node) 1076 """ 1077 options = options or save_options.SaveOptions() 1078 # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x 1079 # compatible (no sessions) and share it with this export API rather than 1080 # making a SavedModel proto and writing it directly. 1081 saved_model = saved_model_pb2.SavedModel() 1082 meta_graph_def = saved_model.meta_graphs.add() 1083 1084 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = ( 1085 _build_meta_graph(obj, signatures, options, meta_graph_def, 1086 raise_metadata_warning)) 1087 saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION 1088 1089 # Write the checkpoint, copy assets into the assets directory, and write out 1090 # the SavedModel proto itself. 1091 if not experimental_skip_checkpoint: 1092 utils_impl.get_or_create_variables_dir(export_dir) 1093 ckpt_options = checkpoint_options.CheckpointOptions( 1094 experimental_io_device=options.experimental_io_device) 1095 object_saver.save( 1096 utils_impl.get_variables_path(export_dir), options=ckpt_options) 1097 builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map, 1098 export_dir) 1099 # Note that this needs to be the last file operation when saving the 1100 # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an 1101 # indication that the SavedModel is completely written. 1102 if context.executing_eagerly(): 1103 try: 1104 context.async_wait() # Ensure save operations have completed. 1105 except errors.NotFoundError as err: 1106 raise FileNotFoundError( 1107 str(err) + "\n If trying to save on a different device from the " 1108 "computational device, consider using setting the " 1109 "`experimental_io_device` option on tf.saved_model.SaveOptions " 1110 "to the io_device such as '/job:localhost'." 1111 ) 1112 1113 path = os.path.join( 1114 compat.as_str(export_dir), 1115 compat.as_str(constants.SAVED_MODEL_FILENAME_PB)) 1116 file_io.atomic_write_string_to_file( 1117 path, saved_model.SerializeToString(deterministic=True)) 1118 # Save debug info, if requested. 1119 if options.save_debug_info: 1120 _export_debug_info(exported_graph, export_dir) 1121 1122 # Clean reference cycles so repeated export()s don't make work for the garbage 1123 # collector. Before this point, we need to keep references to captured 1124 # constants in the saved graph. 1125 ops.dismantle_graph(exported_graph) 1126 1127 return saved_nodes, node_paths 1128 1129 1130def export_meta_graph(obj, filename, signatures=None, options=None): 1131 """Exports the MetaGraph proto of the `obj` to a file. 1132 1133 This function goes through the same procedures saved_model.save goes to 1134 produce the given object's MetaGraph, then saves it to the given file. It 1135 skips saving checkpoint information, and is useful when all one wants is the 1136 graph defining the model. 1137 1138 Args: 1139 obj: A trackable object to build the MetaGraph from. 1140 filename: The file into which to write the MetaGraph. 1141 signatures: Optional, either a `tf.function` with an input signature 1142 specified or the result of `f.get_concrete_function` on a 1143 `@tf.function`-decorated function `f`, in which case `f` will be used to 1144 generate a signature for the SavedModel under the default serving 1145 signature key. `signatures` may also be a dictionary, in which case it 1146 maps from signature keys to either `tf.function` instances with input 1147 signatures or concrete functions. The keys of such a dictionary may be 1148 arbitrary strings, but will typically be from the 1149 `tf.saved_model.signature_constants` module. 1150 options: Optional, `tf.saved_model.SaveOptions` object that specifies 1151 options for saving. 1152 """ 1153 options = options or save_options.SaveOptions() 1154 export_dir = os.path.dirname(filename) 1155 meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph( 1156 obj, signatures, options) 1157 1158 file_io.atomic_write_string_to_file( 1159 filename, meta_graph_def.SerializeToString(deterministic=True)) 1160 1161 # Save debug info, if requested. 1162 if options.save_debug_info: 1163 _export_debug_info(exported_graph, export_dir) 1164 1165 # Clean reference cycles so repeated export()s don't make work for the garbage 1166 # collector. Before this point, we need to keep references to captured 1167 # constants in the saved graph. 1168 ops.dismantle_graph(exported_graph) 1169 1170 1171def _build_meta_graph_impl(obj, 1172 signatures, 1173 options, 1174 meta_graph_def=None, 1175 raise_metadata_warning=True): 1176 """Creates a MetaGraph containing the resources and functions of an object.""" 1177 if ops.inside_function(): 1178 raise AssertionError( 1179 "tf.saved_model.save is not supported inside a traced @tf.function. " 1180 "Move the call to the outer eagerly-executed context.") 1181 # pylint: enable=line-too-long 1182 if not isinstance(obj, base.Trackable): 1183 raise ValueError( 1184 "Expected a Trackable object for export, got {}.".format(obj)) 1185 meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef() 1186 1187 checkpoint_graph_view = _AugmentedGraphView(obj) 1188 if signatures is None: 1189 signatures = signature_serialization.find_function_to_export( 1190 checkpoint_graph_view) 1191 1192 signatures, wrapped_functions = ( 1193 signature_serialization.canonicalize_signatures(signatures)) 1194 signature_serialization.validate_saveable_view(checkpoint_graph_view) 1195 signature_map = signature_serialization.create_signature_map(signatures) 1196 checkpoint_graph_view.add_object( 1197 parent_node=checkpoint_graph_view.root, 1198 name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME, 1199 subgraph_root=signature_map) 1200 1201 # Use _SaveableView to provide a frozen listing of properties and functions. 1202 saveable_view = _SaveableView(checkpoint_graph_view, options, 1203 wrapped_functions) 1204 object_saver = util.TrackableSaver(checkpoint_graph_view) 1205 asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def, 1206 saveable_view, signatures, 1207 options.namespace_whitelist) 1208 if options.function_aliases: 1209 function_aliases = meta_graph_def.meta_info_def.function_aliases 1210 for alias, func in options.function_aliases.items(): 1211 for fdef in func._stateful_fn._function_cache.all_values(): # pylint: disable=protected-access 1212 function_aliases[fdef.name] = alias 1213 for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access 1214 function_aliases[fdef.name] = alias 1215 1216 object_graph_proto, saved_object_metadata = _serialize_object_graph( 1217 saveable_view, asset_info.asset_index) 1218 meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) 1219 1220 if saved_object_metadata and raise_metadata_warning: 1221 tf_logging.warn( 1222 'FOR KERAS USERS: The object that you are saving contains one or more ' 1223 'Keras models or layers. If you are loading the SavedModel with ' 1224 '`tf.keras.models.load_model`, continue reading (otherwise, you may ' 1225 'ignore the following instructions). Please change your code to save ' 1226 'with `tf.keras.models.save_model` or `model.save`, and confirm that ' 1227 'the file "keras.metadata" exists in the export directory. In the ' 1228 'future, Keras will only load the SavedModels that have this file. In ' 1229 'other words, `tf.saved_model.save` will no longer write SavedModels ' 1230 'that can be recovered as Keras models (this will apply in TF 2.5).' 1231 '\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,' 1232 ' this property has been used to save metadata in the SavedModel. The ' 1233 'metadta field will be deprecated soon, so please move the metadata to ' 1234 'a different file.') 1235 1236 return (meta_graph_def, exported_graph, object_saver, asset_info, 1237 saveable_view.nodes, saveable_view.node_paths) 1238 1239 1240def _build_meta_graph(obj, 1241 signatures, 1242 options, 1243 meta_graph_def=None, 1244 raise_metadata_warning=True): 1245 """Creates a MetaGraph under a save context. 1246 1247 Args: 1248 obj: A trackable object to build the MetaGraph from. 1249 signatures: Can be a `tf.function` with an input signature specified or the 1250 result of `f.get_concrete_function` on a `@tf.function`-decorated function 1251 `f`. `signatures` may also be a dictionary, in which case it maps from 1252 signature keys to `tf.function` instances. If None, finds signature to 1253 export from the `@tf.function`-decorated methods in `obj`. 1254 options: `tf.saved_model.SaveOptions` object that specifies options for 1255 saving. 1256 meta_graph_def: Optional, the MetaGraphDef proto fill. 1257 raise_metadata_warning: Whether to raise a warning when user objects contain 1258 non-empty metadata. 1259 1260 Raises: 1261 AssertionError: If `export_meta_graph` is executing inside a `tf.function`. 1262 ValueError: If `obj` is not trackable. 1263 1264 Returns: 1265 meta_graph_def: Filled MetaGraphDef proto 1266 exported_graph: `tf.Graph` object generated from `obj`. 1267 object_saver: `util.TrackableSaver` of the `obj` and its dependencies. 1268 asset_info: `_AssetInfo` tuple containing external assets in the `obj`. 1269 """ 1270 1271 with save_context.save_context(options): 1272 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def, 1273 raise_metadata_warning) 1274