# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a trackable object from a SavedModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import sys

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# API label for SavedModel metrics.
_LOAD_V2_LABEL = "load_v2"


def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed."""
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")

  assert_op = control_flow_ops.Assert(
      array_ops.placeholder_with_default(False, shape=()),
      [error_message])

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)


class _WrapperFunction(function.ConcreteFunction):
  """A class that wraps a concrete function to handle distributed contexts.

  The reason for wrapping a concrete function is that the `_captured_inputs`
  fields used in the in-replica context and the cross-replica context are
  different. When `load()` is called from within a tf.distribute.strategy
  scope, the captured inputs are distributed variables. When using these
  distributed variables to call the function, we need different approaches
  depending on whether we are in-replica or not. When in-replica, we naturally
  use the corresponding component of the distributed variable; when not
  in-replica, calling the function should mean that it is constructing a graph
  that is not actually going to be used. A typical use case is when
  constructing a functional model. In this case, return a placeholder with a
  control dependency to ensure that it is never accessed.
  """

  def __init__(self, concrete_function):
    # Shallow copy the concrete_function.
    self.__dict__.update(vars(concrete_function))

  def _call_flat(self, args, captured_inputs, cancellation_manager=None):

    def get_handle(x):
      return x.handle if distribute_utils.is_distributed_variable(x) else x

    def get_unused_handle(x):
      return _unused_handle() if distribute_utils.is_distributed_variable(x) \
          else x

    if (ds_context.get_replica_context() is not None or
        values_util.is_saving_non_distributed()):
      # If we're in the replica context or are saving a non-distributed version
      # of the model, we resolve the captured variables to the corresponding
      # resource handle. In both situations we call var.handle, but it has
      # different behavior. In the replica context, var.handle resolves to the
      # replica-local variable handle if the variable is replicated. When
      # saving a non-distributed version of the model, var.handle resolves to
      # the primary variable handle, since we only save one copy of a
      # replicated variable.
      captured_inputs = list(map(get_handle, captured_inputs))
    else:  # cross-replica context
      captured_inputs = list(map(get_unused_handle, captured_inputs))
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
                                                    cancellation_manager)


class Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options, save_options, filters):
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library, wrapper_function=_WrapperFunction))
    self._checkpoint_options = ckpt_options
    self._save_options = save_options

    # Stores the user-defined node_filters argument.
    self._node_filters = filters
    # Stores the map from string node paths to integer node ids.
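    # As an illustrative sketch (ids other than the root's are hypothetical):
    # filters=["root.child_layer"] would produce a map like
    # {"root.child_layer": 3}, where the id indexes into self._proto.nodes and
    # the root object itself is always node 0.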
    self._node_path_to_id = self._convert_node_paths_to_ints()
    self._loaded_nodes = {}
    if isinstance(filters, dict):
      # If node_filters is a dict, then the values may contain already created
      # trackable objects. In this case, create a dictionary mapping node IDs
      # to the already created nodes. This dict will be updated in
      # `_retrieve_all_filtered_nodes` with tracked dependencies.
      for node_path, node in filters.items():
        if isinstance(node, tuple):
          self._loaded_nodes[self._node_path_to_id[node_path]] = node
        else:
          self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)

    # Get the set of all integer node ids to load, or None if all nodes should
    # be loaded. This set includes ids of child nodes.
    self._filtered_nodes = self._retrieve_all_filtered_nodes()

    self._load_all()

    if not save_options.experimental_skip_checkpoint:
      self._restore_checkpoint()
    for node in self._nodes:
      if isinstance(node, tracking.CapturableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        if not context.executing_eagerly():
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _convert_node_paths_to_ints(self):
    """Maps all string node paths in node_filters to the int node ids."""
    if self._node_filters is None:
      return None
    path_to_int = {}
    for node_id in self._node_filters:
      int_node_id = None
      if isinstance(node_id, str):
        node_path = node_id.split(".")
        if node_path[0] != "root":
          raise ValueError(
              "When passing string identifiers to node_filters, the first name"
              f" must be root. Received {node_path[0]}.")
        int_node_id = 0
        for n, name in enumerate(node_path[1:]):
          int_node_id = self._find_node_child(
              int_node_id, name, ".".join(node_path[:n+2]))
        path_to_int[node_id] = int_node_id
      else:
        raise TypeError("Elements in node_filters must be strings.")
    return path_to_int

  def _retrieve_all_filtered_nodes(self):
    """Traverses through the object graph to get the IDs of all nodes to load.

    As a side effect, if node_filters is a dictionary that contains already-
    created objects, then the dependencies tracked by those objects will be
    added to node_filters.

    Returns:
      Set of the ids of all nodes to load, or None if all nodes should be
      loaded.
    """
    if self._node_filters is None:
      return None  # All nodes should be loaded.

    all_filtered_nodes = set()
    nodes_to_visit = list(self._node_filters)

    while nodes_to_visit:
      node_path = nodes_to_visit.pop(0)
      node_id = self._node_path_to_id[node_path]
      if node_id in all_filtered_nodes:
        continue
      all_filtered_nodes.add(node_id)

      node, setter = self._loaded_nodes.get(node_id, (None, None))
      if node is not None:
        if not isinstance(node, base.Trackable):
          raise TypeError(
              "Error when processing dictionary values passed to "
              f"nodes_to_load. Object at {node_path} is expected to be a "
              "checkpointable (i.e. 'trackable') TensorFlow object (e.g. "
              "tf.Variable, tf.Module or Keras layer).")
        node._maybe_initialize_trackable()  # pylint: disable=protected-access

      for reference in self._proto.nodes[node_id].children:
        child_object, _ = self._loaded_nodes.get(
            reference.node_id, (None, None))

        # See if node already tracks the child reference, in which case add
        # the child to the loaded_nodes dict.
        if child_object is None and node is not None:
          child_object = node._lookup_dependency(reference.local_name)  # pylint: disable=protected-access
          if isinstance(child_object, data_structures.TrackableDataStructure):
            # Make setattr a noop to avoid overwriting already existing data
            # structures.
            setter = lambda *args: None

        self._loaded_nodes[reference.node_id] = (child_object, setter)

        child_path = "{}.{}".format(node_path, reference.local_name)
        self._node_path_to_id[child_path] = reference.node_id
        nodes_to_visit.append(child_path)

    if 0 in all_filtered_nodes:
      return None
    return all_filtered_nodes

  def _find_node_child(self, node_id, child_name, path):
    for reference in self._proto.nodes[node_id].children:
      if reference.local_name == child_name:
        return reference.node_id
    raise ValueError(f"Unable to find node {path}.")

  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    self._load_nodes()
    self._load_edges()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()

    self._create_saveable_object_factories()

  def _create_saveable_object_factories(self):
    for node_id, proto in self._iter_all_nodes():
      node = self.get(node_id)
      node._self_saveable_object_factories = {}  # pylint: disable=protected-access
      for name, saveable_object_proto in proto.saveable_objects.items():
        node._self_saveable_object_factories[name] = (  # pylint: disable=protected-access
            saveable_object_util.restored_saved_object_factory(
                self.get(saveable_object_proto.save_function),
                self.get(saveable_object_proto.restore_function)))

  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in self._iter_all_nodes():
      self._add_object_graph_edges(object_proto, node_id)

    # If the root object isn't loaded, then create edges from the root for
    # checkpoint compatibility.
    if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
      root = self.get(0)
      for node_path in self._node_filters:
        loaded_node = self._nodes[self._node_path_to_id[node_path]]
        path = node_path.split(".")
        current_node = root
        for name in path[1:-1]:
          if not hasattr(current_node, name):
            setattr(current_node, name, self._recreate_base_user_object()[0])
          current_node = getattr(current_node, name)
        if not hasattr(current_node, path[-1]):
          setattr(current_node, path[-1], loaded_node)

  def _add_object_graph_edges(self, proto, node_id):
    """Adds edges from an object to its children."""
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__` add a class method
      # that allows `obj()` syntax to work. This is done per-instance to
      # allow `callable` to be used to find out if an object is callable.
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)

  def _setup_functions_structures(self):
    """Sets up the structure of inputs and outputs of restored functions."""
    coder = nested_structure_coder.StructureCoder()
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting the structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works
      # with output that is convertible to Tensors and the conversion
      # always happens. For example tf.TensorShape([2, 3]) will be
      # converted to a Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure; and the unpacking logic cares only about structure
      # and types.
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))
      concrete_function._initialize_function_spec()  # pylint: disable=protected-access

  def _setup_functions_captures(self):
    """Sets up captures and variables in restored functions."""
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id, name)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This is only injecting the captured inputs into the
      # concrete function, note that we did not modify the FuncGraph
      # itself.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
      if bound_inputs:
        for bound_input, internal_capture in zip(
            bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
          if distribute_utils.is_distributed_variable(bound_input):
            concrete_function.graph.capture_distributed_variable(
                bound_input, internal_capture)
          else:
            concrete_function.graph.replace_capture(bound_input,
                                                    internal_capture)
            if internal_capture.dtype == dtypes.resource:
              if resource_variable_ops.is_resource_variable(bound_input):
                try:
                  handle = bound_input.handle
                except ValueError:
                  # For mirrored variables we'll copy handle data for
                  # components as they get captured.
                  pass
                else:
                  handle_data_util.copy_handle_data(handle, internal_capture)
              else:
                handle_data_util.copy_handle_data(bound_input,
                                                  internal_capture)
            # Setting "captures" first means "capture" won't create a new
            # placeholder for this input.
            concrete_function.graph.capture(bound_input)

  def _get_tensor_from_node(self, node_id, fn_name):
    """Resolves a node id into a tensor to be captured for a function."""
    if self._node_filters is not None and self._nodes[node_id] is None:
      raise ValueError(
          f"Error when processing nodes_to_load. Function '{fn_name}' requires "
          "inputs/variables that are not loaded when nodes_to_load="
          f"{self._node_filters}.")

    with ops.init_scope():
      obj = self._nodes[node_id]
      if distribute_utils.is_distributed_variable(obj):
        return obj
      elif resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.Asset):
        return obj.asset_path
      elif tensor_util.is_tf_type(obj):
        return obj
      elif isinstance(obj, tracking.CapturableResource):
        # Note: this executes restored functions in the CapturableResource.
        return obj.resource_handle
      raise ValueError(f"Cannot convert node {obj} to tensor.")

  def _initialize_loaded_nodes(self):
    nodes = {}
    node_setters = {}
    for node_id, (node, setter) in self._loaded_nodes.items():
      nodes[node_id] = node
      node_setters[node_id] = setter
    return nodes, node_setters

  def _iter_all_nodes(self):
    if self._filtered_nodes is None:
      return enumerate(self._proto.nodes)
    else:
      return [(node_id, self._proto.nodes[node_id])
              for node_id in self._filtered_nodes]

  def _load_nodes(self):
    """Load all saved objects."""
    # `nodes` maps from node ids to recreated objects
    # `node_setters` maps from node ids to setter functions
    # (same signature as setattr) for setting dependencies.
    nodes, node_setters = self._initialize_loaded_nodes()

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()

    for _, proto in self._iter_all_nodes():
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in self._iter_all_nodes():
      if node_id in slot_variable_node_ids or nodes.get(node_id) is not None:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto, node_id)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in self._iter_all_nodes():
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    # If the root object is not loaded, add a dummy root object for checkpoint
    # compatibility.
    if 0 not in nodes:
      nodes[0] = self._recreate_base_user_object()[0]

    self._nodes = [nodes.get(node_id)
                   for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._save_options.allow_partial_checkpoint:
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
      load_status.assert_nontrivial_match()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
      load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    if not context.executing_eagerly():
      # When running in eager mode, the `restore` call above has already run
      # and restored the state of trackables, and calling
      # `position.restore_ops()` would re-run the restore. In graph mode, that
      # will return a cached list of ops that must run to restore the object
      # at that position. We have to wire them into the initializers of the
      # objects so that they get initialized properly when using common
      # practices (e.g. the ones used by ManagedSession) without further user
      # action.
      for object_id, obj in dict(checkpoint.object_by_proto_id).items():
        position = base.CheckpointPosition(checkpoint=checkpoint,
                                           proto_id=object_id)
        restore_ops = position.restore_ops()
        if restore_ops:
          if resource_variable_ops.is_resource_variable(obj):
            if len(restore_ops) == 1:
              obj._initializer_op = restore_ops[0]
            else:
              obj._initializer_op = control_flow_ops.group(*restore_ops)
          elif isinstance(obj, lookup_ops.LookupInterface):
            # We don't need to check for eager execution here, since this code
            # path should only be taken if we are restoring in graph mode.
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS,
                                  restore_ops)
          else:
            raise NotImplementedError(
                f"Unable to restore state of object {obj} from the "
                "checkpoint.")

  def adjust_debug_info_func_names(self, debug_info):
    """Rewrites func names in the debug info by using the concrete func names."""
    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    output_debug_info.files[:] = debug_info.files
    for key in debug_info.traces:
      node, func = key.split("@")
      new_func = ""
      if func in self._concrete_functions:
        new_func = self._concrete_functions[func].function_def.signature.name
      output_debug_info.traces[node + "@" + new_func].CopyFrom(
          debug_info.traces[key])
    return output_debug_info

  def get(self, node_id):
    if isinstance(node_id, str):
      node_id = self._node_path_to_id[node_id]
    return self._nodes[node_id]

  def _recreate(self, proto, node_id):
    """Creates a Python object from a SavedObject protocol buffer."""
    factory = {
        "user_object": (
            lambda: self._recreate_user_object(proto.user_object, node_id)),
        "asset": lambda: self._recreate_asset(proto.asset),
        "function": lambda: self._recreate_function(proto.function),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto.bare_concrete_function),
        "variable": lambda: self._recreate_variable(proto.variable),
        "constant": lambda: self._recreate_constant(proto.constant),
        "resource": lambda: self._recreate_resource(proto.resource),
        "captured_tensor": functools.partial(
            self._get_tensor_from_fn, proto.captured_tensor),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError(f"Unknown SavedObject type: {kind}. Expected one of "
                       f"{list(factory.keys())}.")
    return factory[kind]()

  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up

  def _recreate_base_user_object(self, proto=None, node_id=None):
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the classes of
    # the object instances that have a `__call__` property.

    class _UserObject(tracking.AutoTrackable):
      pass

    return _UserObject(), setattr

  def _recreate_asset(self, proto):
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    asset = tracking.Asset(filename)
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset.asset_path)
    return asset, setattr

  def _recreate_function(self, proto):
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    name = proto.name if proto.name else None
    if name is not None:
      dbg_name = name
    else:
      dbg_name = "<variable loaded from saved model>"
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            proto.synchronization, proto.aggregation, proto.trainable,
            name=dbg_name))

    def uninitialized_variable_creator(next_creator, **kwargs):
      """A variable creator that creates uninitialized variables."""
      del next_creator
      return resource_variable_ops.UninitializedVariable(**kwargs)

    # Create a variable_creator_scope that creates uninitialized variables
    # with a lower priority such that a potential distributed
    # variable_creator_scope can take precedence.
    with ops.get_default_graph()._variable_creator_scope(  # pylint: disable=protected-access
        uninitialized_variable_creator,
        priority=50):
      return variables.Variable(
          shape=proto.shape,
          dtype=proto.dtype,
          name=name,
          trainable=trainable,
          synchronization=synchronization,
          aggregation=aggregation), setattr

  def _recreate_constant(self, proto):
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      with ops.device("CPU"):
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant, setattr

  def _get_tensor_from_fn(self, proto):
    outer_graph = self._concrete_functions[proto.concrete_function].graph
    captured_tensor = outer_graph.get_tensor_by_name(proto.name)
    return captured_tensor, setattr

  def _recreate_resource(self, proto):
    return _RestoredResource(device=proto.device), _setattr_and_track


# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """Restored SavedResource."""

  def __init__(self, device=""):
    super(_RestoredResource, self).__init__(device=device)

  def _create_resource(self):
    raise RuntimeError()

  def _initialize(self):
    raise RuntimeError()

  # _list_functions_for_serialization expects Function objects, but unlike
  # _create_resource and _initialize, _destroy_resource didn't always exist in
  # older TrackableResource implementations, so this default stub must be a
  # Function.
  @def_function.function
  def _destroy_resource(self):
    raise RuntimeError()

  def _list_functions_for_serialization(self, unused_serialization_cache):
    # Override this method to avoid the implementation of the base class
    # re-wrapping the polymorphic functions into another layer of
    # `tf.function`.
    functions = {
        "_create_resource": self._create_resource,
        "_initialize": self._initialize,
        "_destroy_resource": self._destroy_resource,
    }
    return functions


def _call_attribute(instance, *args, **kwargs):
  return instance.__call__(*args, **kwargs)


def _setattr_and_track(obj, name, value):
  """Sets a new attribute and marks it as a dependency if it is Trackable."""
  setattr(obj, name, value)
  if isinstance(value, base.Trackable):
    obj._track_trackable(value, name)  # pylint:disable=protected-access


@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, a SavedModel stores the **object graph** of the saved
  object. The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`,
  Keras layers, etc.) and edges that are the names of the attributes
  connecting the objects.

  *Example 1*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True
  ```

  *Example 2*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> new_variable.numpy()
  5.0
  ```

  **Loading under different distribution strategies**

  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental, so use with care.

  ```
  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
    0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: A `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  return load_internal(export_dir, tags, options, filters=filters)


@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None, options=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Loading Keras models_

  Keras models are trackable, so they can be saved to SavedModel. The object
  returned by `tf.saved_model.load` is not a Keras object (i.e. it doesn't
  have `.fit`, `.predict`, etc. methods). A few attributes and functions are
  still available: `.variables`, `.trainable_variables` and `.__call__`.

  ```python
  model = tf.keras.Model(...)
  tf.saved_model.save(model, path)
  imported = tf.saved_model.load(path)
  outputs = imported(inputs)
  ```

  Use `tf.keras.models.load_model` to restore the Keras model.

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will be loaded
  with the following attributes:

  * `.signatures`: A dictionary mapping signature names to functions.
  * `.prune(feeds, fetches)`: A method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the
    SavedModel and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details.
  * `.variables`: A list of imported variables.
  * `.graph`: The whole imported graph.
  * `.restore(save_path)`: A function that restores variables from a
    checkpoint saved from `tf.compat.v1.train.Saver`.
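
    For example (a minimal sketch; the checkpoint path is hypothetical):

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    imported.restore("/path/to/v1/checkpoint")
    ```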

  _Consuming SavedModels asynchronously_

  When consuming SavedModels asynchronously (the producer is a separate
  process), the SavedModel directory will appear before all files have been
  written, and `tf.saved_model.load` will fail if pointed at an incomplete
  SavedModel. Rather than checking for the directory, check for
  "saved_model_dir/saved_model.pb". This file is written atomically as the
  last `tf.saved_model.save` file operation.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: A `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects, functions, and debug info with which
    it was saved.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  result = load_internal(export_dir, tags, options)["root"]
  return result


def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation."""
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    metrics.IncrementReadApi(_LOAD_V2_LABEL)
    meta_graph_def = saved_model_proto.meta_graphs[0]
    # The tensor_content field contains raw bytes in little-endian format,
    # which causes problems when loaded on big-endian systems and therefore
    # requires a byte swap.
    if sys.byteorder == "big":
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          f"Got an incompatible argument to `tags`: {tags}. The SavedModel at "
          f"{export_dir} has one MetaGraph with tags "
          f"{meta_graph_def.meta_info_def.tags}. You may omit the argument, "
          "pass 'None', or pass matching tags.")
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n You may be trying to load on a different device "
            "from the computational device. Consider setting the "
            "`experimental_io_device` option in `tf.saved_model.LoadOptions` "
            "to the io_device such as '/job:localhost'.")
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    metrics.IncrementRead(write_version="2")
  else:
    if filters:
      raise ValueError("SavedModels saved from TensorFlow 1.x or Estimator "
                       "(any version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
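

# A minimal sketch of the asynchronous-consumption guidance in the `load`
# docstring: poll for "saved_model.pb", which `tf.saved_model.save` writes
# atomically as its last file operation, rather than polling for the export
# directory itself. Kept as a comment so this module has no side effects; the
# helper name and poll interval are illustrative only.
#
#   import os.path
#   import time
#
#   def wait_and_load(export_dir, poll_seconds=1.0):
#     while not os.path.exists(os.path.join(export_dir, "saved_model.pb")):
#       time.sleep(poll_seconds)
#     return load(export_dir)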