# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel deserialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import types

from google.protobuf import message

from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.protobuf import saved_metadata_pb2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking.tracking import delete_tracking
from tensorflow.python.util import compat
from tensorflow.python.util import nest

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
layers_module = LazyLoader(
    "layers_module", globals(),
    "tensorflow.python.keras.layers")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(),
    "tensorflow.python.keras.engine.functional")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(),
    "tensorflow.python.keras.engine.training_v1")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes


# Names of the SavedModel attributes (functions and checkpointable objects)
# that are attached to revived objects during loading and deleted again in
# `KerasObjectLoader.del_tracking` once restoration is complete.
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects)
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)


def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  Currently, Keras saving/loading only retains the Keras object's weights,
  losses, and call function.

  The loaded model can be re-compiled, but the original optimizer, compiled loss
  functions, and metrics are not retained. This is temporary, and `model.save`
  will soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.


  Returns:
    Object loaded from SavedModel.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if gfile.Exists(path_to_metadata_pb):
    try:
      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e)))
  else:
    # No keras_metadata.pb file: fall back to reading the Keras metadata that
    # pre-TF-2.5 SavedModels embedded directly in the object graph proto.
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config))
      saving_utils.try_build_compiled_arguments(model)
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model


def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Args:
    object_graph_def: The SavedModel's `SavedObjectGraph` proto.
    metadata: A `SavedMetadata` proto; Keras nodes found in the object graph
      are appended to it in place.
  """
  # Older SavedModels store the metadata directly in the proto instead of the
  # separate pb file.
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)


def _generate_object_paths(object_graph_def):
  """Traverses through an ObjectGraphDef and builds a map of all node paths."""
  # Depth-first walk from the root (node 0); the first path found to a node is
  # the one recorded, and revisits are skipped via the `paths` membership test.
  paths = {0: 'root'}
  nodes_to_visit = [0]

  while nodes_to_visit:
    current_node = nodes_to_visit.pop()
    current_path = paths[current_node]
    for reference in object_graph_def.nodes[current_node].children:
      if reference.node_id in paths:
        continue
      paths[reference.node_id] = '{}.{}'.format(current_path,
                                                reference.local_name)
      nodes_to_visit.append(reference.node_id)

  return paths


def _is_graph_network(layer):
  """Determines whether the layer is a graph network (Functional/Sequential)."""
  # pylint: disable=protected-access
  # RevivedNetwork objects were restored from the SavedModel (not the config),
  # so they are never treated as graph networks here.
  if isinstance(layer, RevivedNetwork):
    return False
  elif isinstance(layer, functional_lib.Functional):
    return (layer._is_graph_network or
            isinstance(layer, models_lib.Sequential))
  return False


class KerasObjectLoader(object):
  """Loader that recreates Keras objects (e.g. layers, models).

  Layers and models are revived from either the config or SavedModel following
  these rules:
  1. If object is a graph network (i.e. Sequential or Functional) then it will
     be initialized using the structure from the config only after the children
     layers have been created. Graph networks must be initialized with inputs
     and outputs, so all child layers must be created beforehand.
  2. If object's config exists and the class can be found, then revive from
     config.
  3. Object may have already been created if its parent was revived from config.
     In this case, do nothing.
  4. If nothing of the above applies, compose the various artifacts from the
     SavedModel to create a subclassed layer or model. At this time, custom
     metrics are not supported.

  """

  def __init__(self, metadata, object_graph_def):
    # `metadata`: SavedMetadata proto listing the Keras nodes to revive.
    # `object_graph_def`: the SavedModel's SavedObjectGraph proto.
    self._metadata = metadata
    self._proto = object_graph_def

    self._node_paths = {node_data.node_id: node_data.node_path
                        for node_data in metadata.nodes}
    self.loaded_nodes = {}  # Maps node id -> (loaded object, setter function)

    # Store all node ids that have already been traversed when tracking nodes
    # that were recreated from the config.
    self._traversed_nodes_from_config = set()

    # Maps model id -> (blank model obj, list of child layer or their node ids)
    # This tracks all layers in functional and sequential models. These models
    # are only reconstructed after all of their child layers have been created.
    self.model_layer_dependencies = {}
    self._models_to_reconstruct = []

  def del_tracking(self):
    """Removes tracked references that are only used when loading the model."""
    # Now that the node object has been fully loaded, and the checkpoint has
    # been restored, the object no longer needs to track objects added from
    # SerializedAttributes. (Note that saving a training checkpoint still
    # functions correctly, because layers and variables are tracked separately
    # by the Layer object.)
    # TODO(kathywu): Instead of outright deleting these nodes (which would
    # make restoring from a different checkpoint tricky), mark them as extra
    # dependencies that are OK to overwrite.
    for node in self.loaded_nodes.values():
      node = node[0]  # loaded_nodes values are (object, setter) pairs.
      if not isinstance(node, base_layer.Layer):
        # Loaded nodes can contain other trackable objects created when
        # loading layers from the config, such as variables.
        continue
      for name in PUBLIC_ATTRIBUTES:
        delete_tracking(node, name)

      if isinstance(node, functional_lib.Functional):
        # Delete the temporary layer dependencies, which were used to restore
        # the checkpointed values. When the model is live, the user can delete
        # or add layers to the model at any time, so these layer dependencies
        # may be obsolete.
        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
        for name in dependencies:
          # NOTE(review): `[\d+]` is a character class (one digit or '+');
          # `\d+` was presumably intended. With `re.match` this still matches
          # names like 'layer-0' / 'layer_with_weights-12', so behavior is
          # unchanged for the generated names — confirm before "fixing".
          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
            delete_tracking(node, name)

  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    Args:
      obj: Trackable object that was recreated from the config.
      proto: The `SavedObject` proto corresponding to `obj`.
      node_id: Id of the node for `obj` in the SavedModel object graph.
    """
    # pylint: disable=protected-access
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(proto.user_object.metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      if not isinstance(obj_child, trackable.Trackable):
        continue
      # Choose the setter used when the checkpointed values are restored onto
      # this child: a registered revived-type setter, the Keras revive setter,
      # or plain setattr.
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warn('Looks like there is an object (perhaps variable or '
                       'layer) that is shared between different layers/models. '
                       'This may cause issues when restoring the variable '
                       'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      if isinstance(obj_child, data_structures.TrackableDataStructure):
        # Data structures are fully rebuilt from the config; nothing should be
        # overwritten onto them during restoration.
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter

  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
    """Load all layer nodes from the metadata."""
    # Load metrics after models and layers, since it's likely that models
    # and layers will create the metric when initialized (this avoids wasting
    # time by creating objects multiple times).
    metric_list = []
    for node_metadata in self._metadata.nodes:
      if node_metadata.identifier == constants.METRIC_IDENTIFIER:
        metric_list.append(node_metadata)
        continue

      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
          node_metadata.node_id, node_metadata.identifier,
          node_metadata.metadata)

    for node_metadata in metric_list:
      try:
        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
            node_metadata.node_id, node_metadata.identifier,
            node_metadata.metadata)
      except ValueError:
        # Metrics are only needed when the model is compiled later. We ignore
        # errors when trying to load custom metrics when `compile=False` until
        # custom metrics are serialized properly (b/135550038).
        if compile:
          raise
        logging.warning('Unable to restore custom metric. Please ensure that '
                        'the layer implements `get_config` and `from_config` '
                        'when saving. In addition, please use the '
                        '`custom_objects` arg when calling `load_model()`.')

  def _load_layer(self, node_id, identifier, metadata):
    """Load a single layer from a SavedUserObject proto.

    Args:
      node_id: Id of the layer node in the object graph.
      identifier: The layer's Keras object identifier string.
      metadata: JSON-encoded metadata string saved for the node.

    Returns:
      Tuple of (loaded object, setter function).
    """
    metadata = json_utils.decode(metadata)

    # If node was already created
    if node_id in self.loaded_nodes:
      node, setter = self.loaded_nodes[node_id]

      # Revive setter requires the object to have a `_serialized_attributes`
      # property. Add it here.
      _maybe_add_serialized_attributes(node, metadata)

      config = metadata.get('config')
      if _is_graph_network(node) and generic_utils.validate_config(config):
        child_nodes = self._get_child_layer_node_ids(node_id)
        self.model_layer_dependencies[node_id] = (node, child_nodes)
        if not child_nodes:
          self._models_to_reconstruct.append(node_id)
      return node, setter

    # Detect whether this object can be revived from the config. If not, then
    # revive from the SavedModel instead.
    obj, setter = self._revive_from_config(identifier, metadata, node_id)
    if obj is None:
      obj, setter = revive_custom_object(identifier, metadata)

    # Add an attribute that stores the extra functions/objects saved in the
    # SavedModel. Most of these functions/objects are ignored, but some are
    # used later in the loading process (e.g. the list of regularization
    # losses, or the training config of compiled models).
    _maybe_add_serialized_attributes(obj, metadata)
    return obj, setter

  def _revive_from_config(self, identifier, metadata, node_id):
    """Revives a layer/model from config, or returns None."""
    if identifier == constants.METRIC_IDENTIFIER:
      obj = self._revive_metric_from_config(metadata)
    else:
      obj = (
          self._revive_graph_network(identifier, metadata, node_id) or
          self._revive_layer_or_model_from_config(metadata, node_id))

    if obj is None:
      return None, None

    setter = self._config_node_setter(_revive_setter)
    self._add_children_recreated_from_config(
        obj, self._proto.nodes[node_id], node_id)
    return obj, setter

  def _revive_graph_network(self, identifier, metadata, node_id):
    """Revives a graph network from config, or returns None if infeasible."""
    # Determine whether the metadata contains information for reviving a
    # functional or Sequential model.
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    class_name = compat.as_str(metadata['class_name'])
    # A registered custom class is revived through the regular layer/model
    # deserialization path, not as a blank graph network.
    if generic_utils.get_registered_object(class_name) is not None:
      return None
    model_is_functional_or_sequential = (
        metadata.get('is_graph_network', False) or
        class_name == 'Sequential' or
        class_name == 'Functional')
    if not model_is_functional_or_sequential:
      return None

    # Revive functional and sequential models as blank model objects for now (
    # must be initialized to enable setattr tracking and attribute caching).
    # Reconstruction of the network is deferred until all of the model's layers
    # have been revived.
    if class_name == 'Sequential':
      model = models_lib.Sequential(name=config['name'])
    # The model is a custom Sequential model.
    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
      # Uses the custom class name, since the config does not have one.
      model = models_lib.Sequential(name=class_name)
    else:
      model = models_lib.Functional(
          inputs=[], outputs=[], name=config['name'])

    # Record this model and its layers. This will later be used to reconstruct
    # the model.
    layers = self._get_child_layer_node_ids(node_id)
    self.model_layer_dependencies[node_id] = (model, layers)
    if not layers:
      self._models_to_reconstruct.append(node_id)
    return model

  def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible."""
    # Check that the following requirements are met for reviving from config:
    #    1. Object can be deserialized from config.
    #    2. If the object needs to be built, then the build input shape can be
    #       found.
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
      return None

    try:
      obj = layers_module.deserialize(
          generic_utils.serialize_keras_class_and_config(
              class_name, config, shared_object_id=shared_object_id))
    except ValueError:
      if must_restore_from_config:
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name))
      else:
        return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
      obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
      obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
      obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
      save_spec = metadata.get('save_spec')
      if save_spec is not None:
        obj._set_save_spec(save_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)

    if not built:
      # If the layer cannot be built, revive a custom layer instead.
      return None
    return obj

  def _revive_metric_from_config(self, metadata):
    """Revives a metric object using the config saved in the metadata."""
    class_name = compat.as_str(metadata['class_name'])
    config = metadata.get('config')

    if not generic_utils.validate_config(config):
      return None

    try:
      obj = metrics.deserialize(
          generic_utils.serialize_keras_class_and_config(class_name, config))
    except ValueError:
      return None

    build_input_shape = metadata.get('build_input_shape')
    if build_input_shape is not None and hasattr(obj, '_build'):
      obj._build(build_input_shape)  # pylint: disable=protected-access

    return obj

  def _try_build_layer(self, obj, node_id, build_input_shape):
    """Attempts to build the layer.

    Returns:
      True if the layer was built (or did not need building), False otherwise.
    """
    if obj.built or hasattr(obj.build, '_is_default'):
      obj.built = True
      return True

    if build_input_shape is None:
      # Fall back to the input shape recorded in the saved call function.
      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)

    if build_input_shape is not None:
      obj.build(build_input_shape)
      # Also run the base build to set `built` and input spec bookkeeping even
      # if the subclass build does not call super().build().
      base_layer.Layer.build(obj, build_input_shape)
      return True

    return False
  def _load_edges(self):
    """Add edges for all nodes that are not waiting on initialization."""
    for node_id, proto in enumerate(self._proto.nodes):
      # Models awaiting reconstruction get their edges later, once all child
      # layers exist.
      if node_id not in self.model_layer_dependencies:
        self._add_object_graph_edges(proto, node_id)

  def get_path(self, node_id):
    """Returns the object-graph path (e.g. 'root.layer-0') for `node_id`."""
    return self._node_paths[node_id]

  def finalize_objects(self):
    """Finish setting up Keras objects.

    This function is executed after all objects and functions have been created.
    Call functions and losses are attached to each layer, and once all layers
    have been fully set up, graph networks are initialized.

    Subclassed models that are revived from the SavedModel are treated like
    layers, and have their call/loss functions attached here.
    """
    # Finish setting up layers and subclassed models. This step attaches call
    # functions and losses to each object, and sets model inputs/outputs.
    layers_revived_from_config = []
    layers_revived_from_saved_model = []
    for node_id, (node, _) in self.loaded_nodes.items():
      if (not isinstance(node, base_layer.Layer) or
          # Don't finalize models until all layers have finished loading.
          node_id in self.model_layer_dependencies):
        continue

      self._unblock_model_reconstruction(node_id, node)

      # InputLayers and Metrics need no call/loss finalization.
      if isinstance(node, input_layer.InputLayer):
        continue
      elif isinstance(node, metrics.Metric):
        continue

      if isinstance(node, (RevivedLayer, RevivedInputLayer)):
        layers_revived_from_saved_model.append(node)
      else:
        layers_revived_from_config.append(node)

    _finalize_saved_model_layers(layers_revived_from_saved_model)
    _finalize_config_layers(layers_revived_from_config)

    # Initialize graph networks, now that layer dependencies have been resolved.
    self._reconstruct_all_models()

  def _unblock_model_reconstruction(self, layer_id, layer):
    """Removes layer from blocking model reconstruction."""
    for model_id, v in self.model_layer_dependencies.items():
      _, layers = v
      if layer_id not in layers:
        continue
      # Replace the node id placeholder with the created layer object; once
      # all placeholders are replaced, the model can be reconstructed.
      layers[layers.index(layer_id)] = layer
      if all(isinstance(x, base_layer.Layer) for x in layers):
        self._models_to_reconstruct.append(model_id)

  def _reconstruct_all_models(self):
    """Reconstructs the network structure of all models."""
    all_initialized_models = set()
    while self._models_to_reconstruct:
      model_id = self._models_to_reconstruct.pop(0)
      all_initialized_models.add(model_id)
      model, layers = self.model_layer_dependencies[model_id]
      self._reconstruct_model(model_id, model, layers)
      _finalize_config_layers([model])

    if all_initialized_models != set(self.model_layer_dependencies.keys()):
      # This should not happen.
      uninitialized_model_ids = (
          set(self.model_layer_dependencies.keys()) - all_initialized_models)
      uninitialized_model_names = [
          self.model_layer_dependencies[model_id][0].name
          for model_id in uninitialized_model_ids]
      raise ValueError('Error when loading from SavedModel -- the following '
                       'models could not be initialized: {}'
                       .format(uninitialized_model_names))

  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure."""
    config = json_utils.decode(
        self._proto.nodes[model_id].user_object.metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        # Recreate the implicit InputLayer from the saved config so the
        # Sequential model can be re-initialized with a complete layer list.
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype and trainable status.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)

  def _get_child_layer_node_ids(self, node_id):
    """Returns the node ids of each layer in a Sequential/Functional model."""
    # Sequential and Functional track layers with names following the format
    # "layer-N". Use this to generate the list of layers.
    num_layers = 0
    child_layers = {}
    pattern = re.compile('layer-(\\d+)')

    for child in self._proto.nodes[node_id].children:
      m = pattern.match(child.local_name)
      if m is None:
        continue
      layer_n = int(m.group(1))
      num_layers = max(layer_n + 1, num_layers)
      child_layers[layer_n] = child.node_id

    ordered = []
    for n in range(num_layers):
      child = child_layers.get(n)
      if child is None:  # Some layers may have been pruned from the model.
        break
      ordered.append(child)
    return ordered

  def _search_for_child_node(self, parent_id, path_to_child):
    """Returns node id of child node.

    A helper method for traversing the object graph proto.

    As an example, say that the object graph proto in the SavedModel contains an
    object with the following child and grandchild attributes:

    `parent.child_a.child_b`

    This method can be used to retrieve the node id of `child_b` using the
    parent's node id by calling:

    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.

    Args:
      parent_id: node id of parent node
      path_to_child: list of children names.

    Returns:
      node_id of child, or None if child isn't found.
    """
    if not path_to_child:
      return parent_id

    for child in self._proto.nodes[parent_id].children:
      if child.local_name == path_to_child[0]:
        return self._search_for_child_node(child.node_id, path_to_child[1:])
    return None

  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
    """Infers input shape of layer from SavedModel functions.

    Args:
      layer_node_id: node id of the layer whose inputs are inferred.
      convert_to_shapes: If True, return shapes instead of the full specs.

    Returns:
      The structured input signature (or its shapes) of the layer's saved
      call function, or None if no concrete call function was saved.
    """
    coder = nested_structure_coder.StructureCoder()
    call_fn_id = self._search_for_child_node(
        layer_node_id, ['call_and_return_all_conditional_losses'])
    if call_fn_id is None:
      return None

    concrete_functions = (
        self._proto.nodes[call_fn_id].function.concrete_functions)
    if not concrete_functions:
      return None
    call_fn_name = concrete_functions[0]
    call_fn_proto = self._proto.concrete_functions[call_fn_name]
    structured_input_signature = coder.decode_proto(
        call_fn_proto.canonicalized_input_signature)
    # The first positional argument of the call function is the layer input.
    inputs = structured_input_signature[0][0]
    if convert_to_shapes:
      return nest.map_structure(lambda spec: spec.shape, inputs)
    else:
      return inputs

  def _config_node_setter(self, setter):
    """Creates edges for nodes that are recreated from config."""
    def setattr_wrapper(obj, name, value):
      # Avoid overwriting attributes of objects recreated from the config.
      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
        setter(obj, name, value)
    return setattr_wrapper


def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel."""
  # pylint: disable=protected-access
  # 1. Set up call functions for all layers (skip this step for Sequential and
  # Functional models).
  for layer in layers:
    layer.built = True
    if hasattr(_get_keras_attr(layer), 'call_and_return_conditional_losses'):
      layer.call = utils.use_wrapped_call(
          layer, _get_keras_attr(layer).call_and_return_conditional_losses,
          return_method=True)
      layer._init_call_fn_args()
    else:
      # The call function was not saved; replace it with a stub that raises a
      # descriptive error when invoked.
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      if hasattr(_get_keras_attr(layer), 'call_and_return_conditional_losses'):
        call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
        if call_fn.input_signature is None:
          inputs = infer_inputs_from_restored_call_function(call_fn)
        else:
          inputs = call_fn.input_signature[0]
        layer._set_inputs(inputs)  # pylint: disable=protected-access

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)

  # pylint: enable=protected-access


def _unable_to_call_layer_due_to_serialization_issue(
    layer, *unused_args, **unused_kwargs):
  """Replaces the `layer.call` if the layer was not fully serialized.

  Keras Model/Layer serialization is relatively relaxed because SavedModels
  are not always loaded back as keras models. Thus, when there is an issue
  tracing a non-signature function, a warning is logged instead of raising an
  error. This results in a SavedModel where the model's call function is saved,
  but the internal layer call functions are not.

  When deserialized with `tf.keras.models.load_model`, the internal layers
  which do not have serialized call functions should raise an error when called.

  Args:
    layer: Layer without the serialized call function.

  Raises:
    ValueError
  """

  raise ValueError(
      'Cannot call custom layer {} of type {}, because the call function was '
      'not serialized to the SavedModel.'
      'Please try one of the following methods to fix this issue:'
      '\n\n(1) Implement `get_config` and `from_config` in the layer/model '
      'class, and pass the object to the `custom_objects` argument when '
      'loading the model. For more details, see: '
      'https://www.tensorflow.org/guide/keras/save_and_serialize'
      '\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
      'and not `__call__`. The input shape and dtype will be automatically '
      'recorded when the object is called, and used when saving. To manually '
      'specify the input shape/dtype, decorate the call function with '
      '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))


def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers from config."""
  for layer in layers:
    # It is assumed that layers define their unconditional losses after being
    # recreated from the config and built. The exceptions to this
    # are Functional and Sequential models, which only store conditional losses
    # (losses dependent on the inputs) in the config. Unconditional losses like
    # weight regularization must be revived from the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Some layers, like Dense, record their activation loss function in the
    # config. However, not all layers do this, so the activation loss may be
    # missing when restored from the config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure consistent
    # loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)

    # Restore metrics list.
884 _restore_layer_metrics(layer) 885 886 # Restore RNN layer states 887 if (isinstance(layer, recurrent.RNN) and 888 layer.stateful and 889 hasattr(_get_keras_attr(layer), 'states')): 890 layer.states = getattr(_get_keras_attr(layer), 'states', None) 891 for variable in nest.flatten(layer.states): 892 backend.track_variable(variable) 893 894 895def _finalize_metric(metric): 896 metric.update_state = types.MethodType(metrics_utils.update_state_wrapper( 897 metric.keras_api.update_state), metric) 898 metric.result = metric.keras_api.result 899 900 901def _restore_layer_unconditional_losses(layer): 902 """Restore unconditional losses from SavedModel.""" 903 if hasattr(_get_keras_attr(layer), 'layer_regularization_losses'): 904 losses = getattr(_get_keras_attr(layer), 'layer_regularization_losses', []) 905 else: 906 # Some earlier SavedModels may not have layer_regularization_losses 907 # serialized separately. Fall back to using the regularization_losses 908 # list if it does not exist. 909 losses = layer._serialized_attributes.get('regularization_losses', []) # pylint: disable=protected-access 910 for loss in losses: 911 layer.add_loss(loss) 912 913 914def _restore_layer_activation_loss(layer): 915 """Restore actiation loss from SavedModel.""" 916 # Use wrapped activity regularizer function if the layer's activity 917 # regularizer wasn't created during initialization. 918 activity_regularizer = getattr(_get_keras_attr(layer), 919 'activity_regularizer_fn', None) 920 if activity_regularizer and not layer.activity_regularizer: 921 try: 922 layer.activity_regularizer = activity_regularizer 923 except AttributeError: 924 # This may happen if a layer wrapper is saved with an activity 925 # regularizer. The wrapper object's activity regularizer is unsettable. 
926 pass 927 928 929def revive_custom_object(identifier, metadata): 930 """Revives object from SavedModel.""" 931 if ops.executing_eagerly_outside_functions(): 932 model_class = training_lib.Model 933 else: 934 model_class = training_lib_v1.Model 935 936 revived_classes = { 937 constants.INPUT_LAYER_IDENTIFIER: ( 938 RevivedInputLayer, input_layer.InputLayer), 939 constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer), 940 constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class), 941 constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional), 942 constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential), 943 } 944 parent_classes = revived_classes.get(identifier, None) 945 946 if parent_classes is not None: 947 parent_classes = revived_classes[identifier] 948 revived_cls = type( 949 compat.as_str(metadata['class_name']), parent_classes, {}) 950 return revived_cls._init_from_metadata(metadata) # pylint: disable=protected-access 951 else: 952 raise ValueError('Unable to restore custom object of type {} currently. ' 953 'Please make sure that the layer implements `get_config`' 954 'and `from_config` when saving. In addition, please use ' 955 'the `custom_objects` arg when calling `load_model()`.' 956 .format(identifier)) 957 958 959def _restore_layer_metrics(layer): 960 metrics_list = getattr(_get_keras_attr(layer), 'layer_metrics', {}) 961 layer_metrics = {m.name: m for m in layer._metrics} # pylint: disable=protected-access 962 for name, metric in metrics_list.items(): 963 if name not in layer_metrics: 964 # Metrics may be added during initialization/building of custom layers. 965 layer._metrics.append(metric) # pylint: disable=protected-access 966 967 968# TODO(kathywu): Centrally define keys and functions for both serialization and 969# deserialization. 
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived layer from metadata stored in the SavedModel proto."""
    init_args = dict(
        name=metadata['name'],
        trainable=metadata['trainable'])
    if metadata.get('dtype') is not None:
      init_args['dtype'] = metadata['dtype']
    if metadata.get('batch_input_shape') is not None:
      init_args['batch_input_shape'] = metadata['batch_input_shape']

    revived_obj = cls(**init_args)

    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config
      if metadata.get('input_spec') is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            metadata['input_spec'],
            module_objects={'InputSpec': input_spec.InputSpec})
      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      if metadata.get('_is_feature_layer') is not None:
        revived_obj._is_feature_layer = metadata['_is_feature_layer']
      if metadata.get('stateful') is not None:
        revived_obj.stateful = metadata['stateful']
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    # Serialized keras-specific attributes ("keras_api"), or None if absent.
    return self._serialized_attributes.get(constants.KERAS_ATTR, None)

  def get_config(self):
    if hasattr(self, '_config'):
      return self._config
    else:
      raise NotImplementedError


def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary."""
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, trackable.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  # Note: fixed the regex from `[\d+]` (character class: one digit or a
  # literal '+') to `\d+`, which is the intended "one or more digits".
  elif (isinstance(layer, functional_lib.Functional) and
        re.match(r'^layer(_with_weights)?-\d+', name) is not None):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.
    layer._track_trackable(value, name)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is not None:
    # Don't overwrite already defined attributes.
    pass
  else:
    setattr(layer, name, value)


class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata."""
    init_args = dict(
        name=metadata['name'],
        dtype=metadata['dtype'],
        sparse=metadata['sparse'],
        ragged=metadata['ragged'],
        batch_input_shape=metadata['batch_input_shape'])
    revived_obj = cls(**init_args)
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access

    return revived_obj, setattr

  def get_config(self):
    return self._config


def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserialize Keras object from a nested structure."""
  if isinstance(config, dict):
    if 'class_name' in config:
      # A dict with 'class_name' is a serialized Keras object.
      return generic_utils.deserialize_keras_object(
          config, module_objects=module_objects)
    else:
      # Plain dict: deserialize each value.
      return {key: recursively_deserialize_keras_object(config[key],
                                                        module_objects)
              for key in config}
  if isinstance(config, (tuple, list)):
    return [recursively_deserialize_keras_object(x, module_objects)
            for x in config]
  else:
    raise ValueError('Unable to decode config: {}'.format(config))


def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: `TensorShape` or None.
    y: `TensorShape` or None.

  Returns:
    A `TensorShape` compatible with both inputs, or None if both are None.

  Raises:
    RuntimeError: if exactly one of `x`, `y` is None.
    TypeError: if either input is neither None nor a `TensorShape`.
  """
  # Note: the original `x is None != y is None` is a chained comparison that
  # evaluates as `(x is None) and (None != y) and (y is None)` — always
  # False, so the error below could never fire. Parenthesize for the
  # intended XOR.
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      # Dimensions that disagree (or are unknown) become unknown.
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)


def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  Args:
    fn: Restored layer call function. It is assumed that the inputs are entirely
      in the first argument.

  Returns:
    TensorSpec of call function inputs.
  """
  def common_spec(x, y):
    # Merge two specs into one whose shape is compatible with both.
    common_shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
    elif isinstance(x, ragged_tensor.RaggedTensorSpec):
      return ragged_tensor.RaggedTensorSpec(common_shape, x.dtype)
    return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)

  spec = fn.concrete_functions[0].structured_input_signature[0][0]
  for concrete in fn.concrete_functions[1:]:
    spec2 = concrete.structured_input_signature[0][0]
    spec = nest.map_structure(common_spec, spec, spec2)
  return spec


class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto."""
    revived_obj = cls(name=metadata['name'])

    # Store attributes revived from SerializedAttributes in a un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      # pylint:enable=protected-access

    return revived_obj, _revive_setter  # pylint:disable=protected-access


def _set_network_attributes_from_metadata(revived_obj):
  """Sets attributes recorded in the metadata."""
  with trackable.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    metadata = revived_obj._serialized_attributes['metadata']
    if metadata.get('dtype') is not None:
      revived_obj._set_dtype_policy(metadata['dtype'])
    revived_obj.trainable = metadata['trainable']
    # pylint:enable=protected-access


def _maybe_add_serialized_attributes(layer, metadata):
  """Stashes `metadata` on the layer if no serialized attributes exist yet."""
  # Store attributes revived from SerializedAttributes in a un-tracked
  # dictionary. The attributes are the ones listed in CommonEndpoints or
  # "keras_api" for keras-specific attributes.
  if not hasattr(layer, '_serialized_attributes'):
    with trackable.no_automatic_dependency_tracking_scope(layer):
      layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access


def _get_keras_attr(layer):
  """Returns the layer's revived "keras_api" attributes, or None."""
  return getattr(layer, '_serialized_attributes', {}).get(constants.KERAS_ATTR,
                                                          None)