• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras SavedModel deserialization."""
16
17import os
18import re
19import types
20
21from google.protobuf import message
22
23from tensorflow.python.eager import context
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.keras import backend
29from tensorflow.python.keras import regularizers
30from tensorflow.python.keras.engine import input_spec
31from tensorflow.python.keras.optimizer_v2 import optimizer_v2
32from tensorflow.python.keras.protobuf import saved_metadata_pb2
33from tensorflow.python.keras.protobuf import versions_pb2
34from tensorflow.python.keras.saving import saving_utils
35from tensorflow.python.keras.saving.saved_model import constants
36from tensorflow.python.keras.saving.saved_model import json_utils
37from tensorflow.python.keras.saving.saved_model import utils
38from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
39from tensorflow.python.keras.utils import generic_utils
40from tensorflow.python.keras.utils import metrics_utils
41from tensorflow.python.keras.utils.generic_utils import LazyLoader
42from tensorflow.python.ops.ragged import ragged_tensor
43from tensorflow.python.platform import gfile
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.saved_model import load as tf_load
46from tensorflow.python.saved_model import loader_impl
47from tensorflow.python.saved_model import nested_structure_coder
48from tensorflow.python.saved_model import revived_types
49from tensorflow.python.training.tracking import base as trackable
50from tensorflow.python.training.tracking import data_structures
51from tensorflow.python.util import compat
52from tensorflow.python.util import nest
53
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports. Each LazyLoader below stands in for
# the named module and defers the actual import until it is first used.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
layers_module = LazyLoader(
    "layers_module", globals(),
    "tensorflow.python.keras.layers")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(),
    "tensorflow.python.keras.engine.functional")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(),
    "tensorflow.python.keras.engine.training_v1")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes


# Names of the serialized endpoints (all functions and checkpointable objects
# listed in CommonEndpoints) plus the Keras attribute. After loading,
# `KerasObjectLoader.del_tracking` stops tracking each of these names on every
# loaded layer.
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects)
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
91
92
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  Keras metadata is read from the `keras_metadata.pb` file in the SavedModel
  directory when it exists. For SavedModels written before TF 2.5, the
  metadata is instead reconstructed from the object graph. If the SavedModel
  contains no Keras objects, the result of the core `tf.saved_model` loader is
  returned unchanged.

  When `compile=True` and the root object is a Keras `Model` that was saved
  with a training config, the model is re-compiled from that config. Optimizer
  slot variables are not restored from the SavedModel, so the optimizer starts
  freshly initialized.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.

  Raises:
    IOError: If the Keras metadata file exists but cannot be parsed.
    ValueError: If a legacy SavedModel contains Keras objects that lack Keras
      metadata (i.e. it was written with `tf.saved_model.save`).
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if gfile.Exists(path_to_metadata_pb):
    try:
      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      # Chain the decode error so the root cause is visible in the traceback.
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e))) from e
  else:
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Map object path -> (object, setter) for every node the Keras loader has
  # already created, and request the root object as well. These are handed to
  # the core partial loader.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)
      if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
        if model.optimizer.get_slot_names():
          logging.warning('Your optimizer uses slots. '
                          'Slots cannot be restored from saved_model, '
                          'as a result, your model is starting with  '
                          'a new initialized optimizer.')
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
188
189
def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Older SavedModels store the Keras metadata directly in each node's
  `user_object` proto instead of in the separate `keras_metadata.pb` file.
  This copies that information into `metadata`.

  Args:
    object_graph_def: The object graph proto of the SavedModel
      (`meta_graph_def.object_graph_def`).
    metadata: A `SavedMetadata` proto; one entry is appended to its `nodes`
      for every Keras object found.

  Raises:
    ValueError: If a Keras object node has no metadata, which happens when the
      model was saved with `tf.saved_model.save` instead of `model.save`.
  """
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    # Only user objects with a Keras identifier carry Keras metadata.
    if (proto.WhichOneof('kind') != 'user_object' or
        proto.user_object.identifier not in constants.KERAS_OBJECT_IDENTIFIERS):
      continue
    if not proto.user_object.metadata:
      # Each literal ends with a space so the concatenated sentences are not
      # glued together ("metadata.Please", "`model.save`or").
      raise ValueError('Unable to create a Keras model from this SavedModel. '
                       'This SavedModel was created with '
                       '`tf.saved_model.save`, and lacks the Keras metadata. '
                       'Please save your Keras model by calling `model.save` '
                       'or `tf.keras.models.save_model`.')
    metadata.nodes.add(
        node_id=node_id,
        node_path=node_paths[node_id],
        version=versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]),
        identifier=proto.user_object.identifier,
        metadata=proto.user_object.metadata)
211
212
213def _generate_object_paths(object_graph_def):
214  """Traverses through an ObjectGraphDef and builds a map of all node paths."""
215  paths = {0: 'root'}
216  nodes_to_visit = [0]
217
218  while nodes_to_visit:
219    current_node = nodes_to_visit.pop()
220    current_path = paths[current_node]
221    for reference in object_graph_def.nodes[current_node].children:
222      if reference.node_id in paths:
223        continue
224      paths[reference.node_id] = '{}.{}'.format(current_path,
225                                                reference.local_name)
226      nodes_to_visit.append(reference.node_id)
227
228  return paths
229
230
def _is_graph_network(layer):
  """Returns whether `layer` should be treated as a graph network.

  `RevivedNetwork` instances never count. Otherwise the layer must be a
  `Functional` instance that is either flagged as a graph network or is a
  `Sequential` model.
  """
  # pylint: disable=protected-access
  if isinstance(layer, RevivedNetwork):
    return False
  if not isinstance(layer, functional_lib.Functional):
    return False
  return layer._is_graph_network or isinstance(layer, models_lib.Sequential)
240
241
242class KerasObjectLoader(object):
243  """Loader that recreates Keras objects (e.g. layers, models).
244
245  Layers and models are revived from either the config or SavedModel following
246  these rules:
247  1. If object is a graph network (i.e. Sequential or Functional) then it will
248     be initialized using the structure from the config only after the children
249     layers have been created. Graph networks must be initialized with inputs
250     and outputs, so all child layers must be created beforehand.
251  2. If object's config exists and the class can be found, then revive from
252     config.
253  3. Object may have already been created if its parent was revived from config.
254     In this case, do nothing.
255  4. If nothing of the above applies, compose the various artifacts from the
256     SavedModel to create a subclassed layer or model. At this time, custom
257     metrics are not supported.
258
259  """
260
261  def __init__(self, metadata, object_graph_def):
262    self._metadata = {x.node_id: x for x in metadata.nodes}
263    self._proto = object_graph_def
264
265    self._node_paths = {node_data.node_id: node_data.node_path
266                        for node_data in metadata.nodes}
267    self.loaded_nodes = {}  # Maps node path -> loaded node
268
269    # Store all node ids that have already been traversed when tracking nodes
270    # that were recreated from the config.
271    self._traversed_nodes_from_config = set()
272
273    # Maps model id -> (blank model obj, list of child layer or their node ids)
274    # This tracks all layers in functional and sequential models. These models
275    # are only reconstructed after all of their child layers have been created.
276    self.model_layer_dependencies = {}
277    self._models_to_reconstruct = []
278
279  def del_tracking(self):
280    """Removes tracked references that are only used when loading the model."""
281    # Now that the node object has been fully loaded, and the checkpoint has
282    # been restored, the object no longer needs to track objects added from
283    # SerializedAttributes. (Note that saving a training checkpoint still
284    # functions correctly, because layers and variables are tracked separately
285    # by the Layer object.)
286    # TODO(kathywu): Instead of outright deleting these nodes (which would
287    # make restoring from a different checkpoint tricky), mark them as extra
288    # dependencies that are OK to overwrite.
289    for node in self.loaded_nodes.values():
290      node = node[0]
291      if not isinstance(node, base_layer.Layer):
292        # Loaded nodes can contain other trackable objects created when
293        # loading layers from the config, such as variables.
294        continue
295      for name in PUBLIC_ATTRIBUTES:
296        node._delete_tracking(name)  # pylint: disable=protected-access
297
298      if isinstance(node, functional_lib.Functional):
299        # Delete the temporary layer dependencies, which were used to restore
300        # the checkpointed values. When the model is live, the user can delete
301        # or add layers to the model at any time, so these layer dependencies
302        # may be obsolete.
303        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
304        for name in dependencies:
305          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
306            node._delete_tracking(name)  # pylint: disable=protected-access
307
  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    When an object is revived from its config, its children (sublayers,
    variables, metrics, ...) are created by the object itself rather than by
    this loader. This walks the SavedModel children of `node_id`, looks up the
    matching child on `obj`, and records each trackable child in
    `self.loaded_nodes` (together with a setter) and in `self._node_paths`.
    `load()` later passes these entries to the core partial loader by path.

    Args:
      obj: The object recreated from config for node `node_id`.
      proto: The SavedModel object proto of node `node_id`.
      node_id: Node id of `obj` in the object graph.
    """
    # pylint: disable=protected-access
    # Each node is traversed at most once; this also stops the recursion when
    # objects are shared or reachable through several paths.
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(self._metadata[node_id].metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    # Each entry is (child object or None, child node id, path suffix).
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      # Skip children that were not found on `obj` or are not trackable.
      if not isinstance(obj_child, trackable.Trackable):
        continue
      # Choose how the child will be attached to its parent: a registered
      # revived type's setter, the Keras revive setter, or plain setattr.
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning(
              'Looks like there is an object (perhaps variable or '
              'layer) that is shared between different layers/models. '
              'This may cause issues when restoring the variable '
              'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      # Trackable data structures get a no-op setter, so nothing is assigned
      # to the parent for them later.
      if isinstance(obj_child, data_structures.TrackableDataStructure):
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      # Recurse before registering, so grandchildren are recorded as well.
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter
380
381  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
382    """Load all layer nodes from the metadata."""
383    # Load metrics after models and layers, since it's likely that models
384    # and layers will create the metric when initialized (this avoids wasting
385    # time by creating objects multiple times).
386    metric_list = []
387    for node_metadata in self._metadata.values():
388      if node_metadata.identifier == constants.METRIC_IDENTIFIER:
389        metric_list.append(node_metadata)
390        continue
391
392      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
393          node_metadata.node_id, node_metadata.identifier,
394          node_metadata.metadata)
395
396    for node_metadata in metric_list:
397      try:
398        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
399            node_metadata.node_id, node_metadata.identifier,
400            node_metadata.metadata)
401      except ValueError:
402        # Metrics are only needed when the model is compiled later. We ignore
403        # errors when trying to load custom metrics when `compile=False` until
404        # custom metrics are serialized properly (b/135550038).
405        if compile:
406          raise
407        logging.warning('Unable to restore custom metric. Please ensure that '
408                        'the layer implements `get_config` and `from_config` '
409                        'when saving. In addition, please use the '
410                        '`custom_objects` arg when calling `load_model()`.')
411
412  def _load_layer(self, node_id, identifier, metadata):
413    """Load a single layer from a SavedUserObject proto."""
414    metadata = json_utils.decode(metadata)
415
416    # If node was already created
417    if node_id in self.loaded_nodes:
418      node, setter = self.loaded_nodes[node_id]
419
420      # Revive setter requires the object to have a `_serialized_attributes`
421      # property. Add it here.
422      _maybe_add_serialized_attributes(node, metadata)
423
424      config = metadata.get('config')
425      if _is_graph_network(node) and generic_utils.validate_config(config):
426        child_nodes = self._get_child_layer_node_ids(node_id)
427        self.model_layer_dependencies[node_id] = (node, child_nodes)
428        if not child_nodes:
429          self._models_to_reconstruct.append(node_id)
430      return node, setter
431
432    # Detect whether this object can be revived from the config. If not, then
433    # revive from the SavedModel instead.
434    obj, setter = self._revive_from_config(identifier, metadata, node_id)
435    if obj is None:
436      obj, setter = revive_custom_object(identifier, metadata)
437
438    # Add an attribute that stores the extra functions/objects saved in the
439    # SavedModel. Most of these functions/objects are ignored, but some are
440    # used later in the loading process (e.g. the list of regularization
441    # losses, or the training config of compiled models).
442    _maybe_add_serialized_attributes(obj, metadata)
443    return obj, setter
444
445  def _revive_from_config(self, identifier, metadata, node_id):
446    """Revives a layer/model from config, or returns None."""
447    if identifier == constants.METRIC_IDENTIFIER:
448      obj = self._revive_metric_from_config(metadata)
449    else:
450      obj = (
451          self._revive_graph_network(identifier, metadata, node_id) or
452          self._revive_layer_or_model_from_config(metadata, node_id))
453
454    if obj is None:
455      return None, None
456
457    setter = self._config_node_setter(_revive_setter)
458    self._add_children_recreated_from_config(
459        obj, self._proto.nodes[node_id], node_id)
460    return obj, setter
461
462  def _revive_graph_network(self, identifier, metadata, node_id):
463    """Revives a graph network from config."""
464    # Determine whether the metadata contains information for reviving a
465    # functional or Sequential model.
466    config = metadata.get('config')
467    if not generic_utils.validate_config(config):
468      return None
469
470    class_name = compat.as_str(metadata['class_name'])
471    if generic_utils.get_registered_object(class_name) is not None:
472      return None
473    model_is_functional_or_sequential = (
474        metadata.get('is_graph_network', False) or
475        class_name == 'Sequential' or
476        class_name == 'Functional')
477    if not model_is_functional_or_sequential:
478      return None
479
480    # Revive functional and sequential models as blank model objects for now (
481    # must be initialized to enable setattr tracking and attribute caching).
482    # Reconstruction of the network is deferred until all of the model's layers
483    # have been revived.
484    if class_name == 'Sequential':
485      model = models_lib.Sequential(name=config['name'])
486    # The model is a custom Sequential model.
487    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
488      # Uses the custom class name, since the config does not have one.
489      model = models_lib.Sequential(name=class_name)
490    else:
491      model = models_lib.Functional(
492          inputs=[], outputs=[], name=config['name'])
493
494    # Record this model and its layers. This will later be used to reconstruct
495    # the model.
496    layers = self._get_child_layer_node_ids(node_id)
497    self.model_layer_dependencies[node_id] = (model, layers)
498    if not layers:
499      self._models_to_reconstruct.append(node_id)
500    return model
501
502  def _revive_layer_or_model_from_config(self, metadata, node_id):
503    """Revives a layer/custom model from config; returns None if infeasible."""
504    # Check that the following requirements are met for reviving from config:
505    #    1. Object can be deserialized from config.
506    #    2. If the object needs to be built, then the build input shape can be
507    #       found.
508    class_name = metadata.get('class_name')
509    config = metadata.get('config')
510    shared_object_id = metadata.get('shared_object_id')
511    must_restore_from_config = metadata.get('must_restore_from_config')
512    if not generic_utils.validate_config(config):
513      return None
514
515    try:
516      obj = layers_module.deserialize(
517          generic_utils.serialize_keras_class_and_config(
518              class_name, config, shared_object_id=shared_object_id))
519    except ValueError:
520      if must_restore_from_config:
521        raise RuntimeError(
522            'Unable to restore a layer of class {cls}. Layers of '
523            'class {cls} require that the class be provided to '
524            'the model loading code, either by registering the '
525            'class using @keras.utils.register_keras_serializable '
526            'on the class def and including that file in your '
527            'program, or by passing the class in a '
528            'keras.utils.CustomObjectScope that wraps this load '
529            'call.'.format(cls=class_name))
530      else:
531        return None
532
533    # Use the dtype, name, and trainable status. Often times these are not
534    # specified in custom configs, so retrieve their values from the metadata.
535    # pylint: disable=protected-access
536    obj._name = metadata['name']
537    if metadata.get('trainable') is not None:
538      obj.trainable = metadata['trainable']
539    if metadata.get('dtype') is not None:
540      obj._set_dtype_policy(metadata['dtype'])
541    if metadata.get('stateful') is not None:
542      obj.stateful = metadata['stateful']
543    # Restore model save spec for subclassed models. (layers do not store a
544    # SaveSpec)
545    if isinstance(obj, training_lib.Model):
546      save_spec = metadata.get('save_spec')
547      if save_spec is not None:
548        obj._set_save_spec(save_spec)
549    # pylint: enable=protected-access
550
551    build_input_shape = metadata.get('build_input_shape')
552    built = self._try_build_layer(obj, node_id, build_input_shape)
553
554    if not built:
555      # If the layer cannot be built, revive a custom layer instead.
556      return None
557    return obj
558
559  def _revive_metric_from_config(self, metadata):
560    """Revives a metric object using the config saved in the metadata."""
561    class_name = compat.as_str(metadata['class_name'])
562    config = metadata.get('config')
563
564    if not generic_utils.validate_config(config):
565      return None
566
567    try:
568      obj = metrics.deserialize(
569          generic_utils.serialize_keras_class_and_config(class_name, config))
570    except ValueError:
571      return None
572
573    build_input_shape = metadata.get('build_input_shape')
574    if build_input_shape is not None and hasattr(obj, '_build'):
575      obj._build(build_input_shape)  # pylint: disable=protected-access
576
577    return obj
578
579  def _try_build_layer(self, obj, node_id, build_input_shape):
580    """Attempts to build the layer."""
581    if obj.built or hasattr(obj.build, '_is_default'):
582      obj.built = True
583      return True
584
585    if build_input_shape is None:
586      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
587
588    if build_input_shape is not None:
589      obj.build(build_input_shape)
590      base_layer.Layer.build(obj, build_input_shape)
591      return True
592
593    return False
594
595  def _load_edges(self):
596    """Add edges for all nodes that are not waiting on initialization."""
597    for node_id, proto in enumerate(self._proto.nodes):
598      if node_id not in self.model_layer_dependencies:
599        self._add_object_graph_edges(proto, node_id)
600
601  def get_path(self, node_id):
602    return self._node_paths[node_id]
603
604  def finalize_objects(self):
605    """Finish setting up Keras objects.
606
607    This function is executed after all objects and functions have been created.
608    Call functions and losses are attached to each layer, and once all layers
609    have been fully set up, graph networks are initialized.
610
611    Subclassed models that are revived from the SavedModel are treated like
612    layers, and have their call/loss functions attached here.
613    """
614    # Finish setting up layers and subclassed models. This step attaches call
615    # functions and losses to each object, and sets model inputs/outputs.
616    layers_revived_from_config = []
617    layers_revived_from_saved_model = []
618    for node_id, (node, _) in self.loaded_nodes.items():
619      if (not isinstance(node, base_layer.Layer) or
620          # Don't finalize models until all layers have finished loading.
621          node_id in self.model_layer_dependencies):
622        continue
623
624      self._unblock_model_reconstruction(node_id, node)
625
626      if isinstance(node, input_layer.InputLayer):
627        continue
628      elif isinstance(node, metrics.Metric):
629        continue
630
631      if isinstance(node, (RevivedLayer, RevivedInputLayer)):
632        layers_revived_from_saved_model.append(node)
633      else:
634        layers_revived_from_config.append(node)
635
636    _finalize_saved_model_layers(layers_revived_from_saved_model)
637    _finalize_config_layers(layers_revived_from_config)
638
639    # Initialize graph networks, now that layer dependencies have been resolved.
640    self._reconstruct_all_models()
641
642  def _unblock_model_reconstruction(self, layer_id, layer):
643    """Removes layer from blocking model reconstruction."""
644    for model_id, v in self.model_layer_dependencies.items():
645      _, layers = v
646      if layer_id not in layers:
647        continue
648      layers[layers.index(layer_id)] = layer
649      if all(isinstance(x, base_layer.Layer) for x in layers):
650        self._models_to_reconstruct.append(model_id)
651
652  def _reconstruct_all_models(self):
653    """Reconstructs the network structure of all models."""
654    all_initialized_models = set()
655    while self._models_to_reconstruct:
656      model_id = self._models_to_reconstruct.pop(0)
657      all_initialized_models.add(model_id)
658      model, layers = self.model_layer_dependencies[model_id]
659      self._reconstruct_model(model_id, model, layers)
660      _finalize_config_layers([model])
661
662    if all_initialized_models != set(self.model_layer_dependencies.keys()):
663      # This should not happen.
664      uninitialized_model_ids = (
665          set(self.model_layer_dependencies.keys()) - all_initialized_models)
666      uninitialized_model_names = [
667          self.model_layer_dependencies[model_id][0].name
668          for model_id in uninitialized_model_ids]
669      raise ValueError('Error when loading from SavedModel -- the following '
670                       'models could not be initialized: {}'
671                       .format(uninitialized_model_names))
672
  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure.

    Re-initializes the blank Sequential/Functional `model` (created by
    `_revive_graph_network`) from its saved config and revived `layers`, then
    unblocks any models that depend on it.

    Args:
      model_id: Node id of the model.
      model: The blank model object to initialize in place.
      layers: The model's revived child layers, in "layer-N" order. This is
        the same list stored in `self.model_layer_dependencies` and may be
        mutated here (an input layer can be inserted at the front).
    """
    config = json_utils.decode(self._metadata[model_id].metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      # Ensure the Sequential model starts with an InputLayer, rebuilt from the
      # first layer's config when possible.
      # NOTE(review): if `layers` is empty but the config has a
      # `batch_input_shape`, `layers[0]` below raises IndexError; likewise
      # `config['layers'][0]` for an empty config - confirm unreachable.
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        # No input layer could be created from the config: infer the inputs
        # from the model's first child layer instead.
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    # (Presumably along with other network attributes kept in the metadata.)
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)
714
715  def _get_child_layer_node_ids(self, node_id):
716    """Returns the node ids of each layer in a Sequential/Functional model."""
717    # Sequential and Functional track layers with names following the format
718    # "layer-N". Use this to generate the list of layers.
719    num_layers = 0
720    child_layers = {}
721    pattern = re.compile('layer-(\\d+)')
722
723    for child in self._proto.nodes[node_id].children:
724      m = pattern.match(child.local_name)
725      if m is None:
726        continue
727      layer_n = int(m.group(1))
728      num_layers = max(layer_n + 1, num_layers)
729      child_layers[layer_n] = child.node_id
730
731    ordered = []
732    for n in range(num_layers):
733      child = child_layers.get(n)
734      if child is None:
735        break
736      ordered.append(child)
737    return ordered
738
739  def _search_for_child_node(self, parent_id, path_to_child):
740    """Returns node id of child node.
741
742    A helper method for traversing the object graph proto.
743
744    As an example, say that the object graph proto in the SavedModel contains an
745    object with the following child and grandchild attributes:
746
747    `parent.child_a.child_b`
748
749    This method can be used to retrieve the node id of `child_b` using the
750    parent's node id by calling:
751
752    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.
753
754    Args:
755      parent_id: node id of parent node
756      path_to_child: list of children names.
757
758    Returns:
759      node_id of child, or None if child isn't found.
760    """
761    if not path_to_child:
762      return parent_id
763
764    for child in self._proto.nodes[parent_id].children:
765      if child.local_name == path_to_child[0]:
766        return self._search_for_child_node(child.node_id, path_to_child[1:])
767    return None
768
769  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
770    """Infers input shape of layer from SavedModel functions."""
771    coder = nested_structure_coder.StructureCoder()
772    call_fn_id = self._search_for_child_node(
773        layer_node_id, ['call_and_return_all_conditional_losses'])
774    if call_fn_id is None:
775      return None
776
777    concrete_functions = (
778        self._proto.nodes[call_fn_id].function.concrete_functions)
779    if not concrete_functions:
780      return None
781    call_fn_name = concrete_functions[0]
782    call_fn_proto = self._proto.concrete_functions[call_fn_name]
783    structured_input_signature = coder.decode_proto(
784        call_fn_proto.canonicalized_input_signature)
785    inputs = structured_input_signature[0][0]
786    if convert_to_shapes:
787      return nest.map_structure(lambda spec: spec.shape, inputs)
788    else:
789      return inputs
790
791  def _config_node_setter(self, setter):
792    """Creates edges for nodes that are recreated from config."""
793    def setattr_wrapper(obj, name, value):
794      # Avoid overwriting attributes of objects recreated from the config.
795      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
796        setter(obj, name, value)
797    return setattr_wrapper
798
799
def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel.

  Applies to layers revived from the SavedModel rather than recreated from
  their config. The steps are:
    1. Mark every layer as built. Replace its `call` with the restored
       `call_and_return_conditional_losses` function, or with a stub that
       raises if no concrete function was serialized.
    2. For revived networks, apply the saved dtype policy and `trainable`
       flag, and set the network inputs from the restored call function.
    3. Restore unconditional and activity-regularization losses.
    4. Restore the metrics list.

  Steps 3 and 4 run for every layer. This includes revived networks whose
  call function has no concrete functions, since neither step depends on the
  network's inputs.

  Args:
    layers: Collection of revived layers. It is iterated twice, so it must
      not be a one-shot iterator.
  """
  # pylint: disable=protected-access
  # 1. Set up call functions for all layers initialized from the SavedModel (
  # and not the config)
  for layer in layers:
    layer.built = True
    layer_call = getattr(_get_keras_attr(layer),
                         'call_and_return_conditional_losses', None)
    if layer_call and layer_call.concrete_functions:
      layer.call = utils.use_wrapped_call(
          layer, layer_call, return_method=True)
      expects_training_arg = layer._serialized_attributes['metadata'][
          'expects_training_arg']
      if 'training' in layer_call.function_spec.arg_names:
        # This could change the value of `expects_training_arg` if this layer
        # doesn't expect a training arg, but has a child layer that does.
        expects_training_arg = True
      layer._init_call_fn_args(expects_training_arg)
    else:
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)
      # The helper returns early when there is nothing to infer inputs from.
      # Steps 3 and 4 below must still run for such networks.
      _set_revived_network_inputs(layer)

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)

  # pylint: enable=protected-access


def _set_revived_network_inputs(layer):
  """Sets a revived network's inputs from its restored call function.

  Does nothing when the network has no restored
  `call_and_return_conditional_losses` function, or when that function has no
  concrete functions to read an input signature from.

  Args:
    layer: A revived network.
  """
  keras_attr = _get_keras_attr(layer)
  if not hasattr(keras_attr, 'call_and_return_conditional_losses'):
    return
  call_fn = keras_attr.call_and_return_conditional_losses
  if not call_fn.concrete_functions:
    return
  if call_fn.input_signature is None:
    # No explicit signature: merge the signatures of all traced functions.
    inputs = infer_inputs_from_restored_call_function(call_fn)
  else:
    inputs = call_fn.input_signature[0]
  layer._set_inputs(inputs)  # pylint: disable=protected-access
846
847
def _unable_to_call_layer_due_to_serialization_issue(
    layer, *unused_args, **unused_kwargs):
  """Replaces the `layer.call` if the layer was not fully serialized.

  Keras Model/Layer serialization is relatively relaxed because SavedModels
  are not always loaded back as keras models. When tracing a non-signature
  function fails, a warning is logged instead of raising an error. This can
  produce a SavedModel where the model's call function is saved but the
  internal layer call functions are not.

  When deserialized with `tf.keras.models.load_model`, the internal layers
  that have no serialized call function should raise an error when called.

  Args:
    layer: Layer without the serialized call function.
    *unused_args: Ignored; accepted so this can stand in for any `call`.
    **unused_kwargs: Ignored.

  Raises:
    ValueError: Always, with suggestions for fixing the SavedModel.
  """

  raise ValueError(
      'Cannot call custom layer {} of type {}, because the call function was '
      'not serialized to the SavedModel. '
      'Please try one of the following methods to fix this issue:'
      '\n\n(1) Implement `get_config` and `from_config` in the layer/model '
      'class, and pass the object to the `custom_objects` argument when '
      'loading the model. For more details, see: '
      'https://www.tensorflow.org/guide/keras/save_and_serialize'
      '\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
      'and not `__call__`. The input shape and dtype will be automatically '
      'recorded when the object is called, and used when saving. To manually '
      'specify the input shape/dtype, decorate the call function with '
      '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
881
882
def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers that were built from config.

  For each layer, this function:
    - restores unconditional losses (graph networks only),
    - restores the activity regularizer,
    - restores the metrics list,
    - restores the saved `states` variables for stateful RNN layers,
    - calls `layer.finalize_state()`.

  Args:
    layers: Iterable of layers recreated from their config.
  """
  for layer in layers:
    # Layers are expected to define their unconditional losses themselves once
    # recreated from config and built. Functional and Sequential models are
    # the exception: their config only holds conditional (input-dependent)
    # losses, so unconditional ones such as weight regularization come from
    # the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Some layers (e.g. Dense) record their activation loss function in the
    # config, but not all do, so it may be missing after restoring from
    # config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure consistent
    # loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)

    _restore_layer_metrics(layer)

    # Stateful RNNs get their saved state variables back and tracked.
    is_stateful_rnn = isinstance(layer, recurrent.RNN) and layer.stateful
    if is_stateful_rnn and hasattr(_get_keras_attr(layer), 'states'):
      layer.states = getattr(_get_keras_attr(layer), 'states', None)
      for state_variable in nest.flatten(layer.states):
        backend.track_variable(state_variable)

    # Let the layer apply any finalization of its own state.
    layer.finalize_state()
914
915
def _finalize_metric(metric):
  """Points a revived metric's `update_state` and `result` at its `keras_api`.

  `keras_api.update_state` is wrapped with
  `metrics_utils.update_state_wrapper` and bound to `metric` as a method.
  `result` is copied from `keras_api` unchanged.

  Args:
    metric: Revived metric object exposing a `keras_api` attribute.
  """
  wrapped_update_state = metrics_utils.update_state_wrapper(
      metric.keras_api.update_state)
  metric.update_state = types.MethodType(wrapped_update_state, metric)
  metric.result = metric.keras_api.result
920
921
def _restore_layer_unconditional_losses(layer):
  """Adds a layer's saved unconditional losses back via `layer.add_loss`.

  The losses come from the `layer_regularization_losses` attribute of the
  layer's `keras_api`. SavedModels written before that attribute existed fall
  back to the `regularization_losses` entry of `_serialized_attributes`.

  Args:
    layer: Layer revived from, or recreated alongside, a SavedModel.
  """
  keras_attr = _get_keras_attr(layer)
  if hasattr(keras_attr, 'layer_regularization_losses'):
    saved_losses = getattr(keras_attr, 'layer_regularization_losses', [])
  else:
    saved_losses = layer._serialized_attributes.get('regularization_losses', [])  # pylint: disable=protected-access
  for saved_loss in saved_losses:
    layer.add_loss(saved_loss)
933
934
def _restore_layer_activation_loss(layer):
  """Restores the activity regularizer from SavedModel if the layer lacks one.

  The wrapped `activity_regularizer_fn` from the layer's `keras_api` is used
  only when the layer did not already create an activity regularizer during
  initialization.

  Args:
    layer: Layer revived from, or recreated alongside, a SavedModel.
  """
  saved_regularizer = getattr(_get_keras_attr(layer),
                              'activity_regularizer_fn', None)
  if not saved_regularizer or layer.activity_regularizer:
    return
  try:
    layer.activity_regularizer = saved_regularizer
  except AttributeError:
    # A layer wrapper saved with an activity regularizer does not allow its
    # activity regularizer to be set, so leave it as is.
    pass
948
949
def revive_custom_object(identifier, metadata):
  """Revives object from SavedModel.

  Builds a new class named `metadata['class_name']` that combines the
  matching revived class (RevivedInputLayer, RevivedLayer or RevivedNetwork)
  with the matching Keras base class. The object is then initialized from
  `metadata`.

  Args:
    identifier: Keras object identifier stored in the SavedModel, one of the
      `constants.*_IDENTIFIER` values.
    metadata: Decoded metadata dict for the object. Must contain
      `'class_name'` plus whatever the revived class's `_init_from_metadata`
      reads.

  Returns:
    The `(revived_object, setter)` tuple produced by `_init_from_metadata`.

  Raises:
    ValueError: If `identifier` is not a revivable Keras object type.
  """
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer, input_layer.InputLayer),
      constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
      constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
      constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
      constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
  }
  parent_classes = revived_classes.get(identifier, None)
  if parent_classes is None:
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config` '
                     'and `from_config` when saving. In addition, please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))

  revived_cls = type(
      compat.as_str(metadata['class_name']), parent_classes, {})
  return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
978
979
def _restore_layer_metrics(layer):
  """Appends saved metrics whose names the layer does not already track.

  Args:
    layer: Layer whose `keras_api.layer_metrics` dict (name -> metric) is
      merged into `layer._metrics`.
  """
  saved_metrics = getattr(_get_keras_attr(layer), 'layer_metrics', {})
  existing_names = {m.name for m in layer._metrics}  # pylint: disable=protected-access
  for metric_name, saved_metric in saved_metrics.items():
    # Metrics may be added during initialization/building of custom layers;
    # don't duplicate those.
    if metric_name in existing_names:
      continue
    layer._metrics.append(saved_metric)  # pylint: disable=protected-access
987
988
# TODO(kathywu): Centrally define keys and functions for both serialization and
# deserialization.
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived layer from metadata stored in the SavedModel proto.

    Args:
      metadata: Decoded metadata dict. Must contain `name`, `trainable` and
        `expects_training_arg`; other keys are used when present.

    Returns:
      A `(revived_layer, setter)` tuple, where `setter` is `_revive_setter`.
    """
    init_args = {'name': metadata['name'], 'trainable': metadata['trainable']}
    for optional_arg in ('dtype', 'batch_input_shape'):
      if metadata.get(optional_arg) is not None:
        init_args[optional_arg] = metadata[optional_arg]

    revived_obj = cls(**init_args)

    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        revived_obj._config = saved_config
      saved_input_spec = metadata.get('input_spec')
      if saved_input_spec is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            saved_input_spec,
            module_objects={'InputSpec': input_spec.InputSpec})
      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            saved_regularizer)
      if metadata.get('_is_feature_layer') is not None:
        revived_obj._is_feature_layer = metadata['_is_feature_layer']
      if metadata.get('stateful') is not None:
        revived_obj.stateful = metadata['stateful']
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    """The revived `keras_api` object from the SavedModel, or None."""
    return self._serialized_attributes.get(constants.KERAS_ATTR, None)

  def get_config(self):
    """Returns the config saved with the layer.

    Raises:
      NotImplementedError: If no valid config was restored.
    """
    try:
      return self._config
    except AttributeError:
      raise NotImplementedError from None
1037
1038
def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary.

  Returned alongside revived layers/networks by `_init_from_metadata`.

  - Names in `PUBLIC_ATTRIBUTES` conflict with properties defined on Layer
    and Model. They are stored in `layer._serialized_attributes` instead, and
    also tracked as dependencies when the value is trackable.
  - On Functional models, "layer-N" / "layer_with_weights-N" edges are only
    tracked as dependencies so that checkpointed values can be restored.
  - Otherwise the attribute is set with `setattr`, unless the layer already
    has a non-None value for it.

  Args:
    layer: The revived layer or network.
    name: Attribute/edge name from the SavedModel object graph.
    value: The revived child object.
  """
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, trackable.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  elif (isinstance(layer, functional_lib.Functional) and
        re.match(r'^layer(_with_weights)?-\d+', name) is not None):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.

    # Set `overwrite=True` in the case that `layer` already tracks a different
    # layer-n. This may cause variable values to not be loaded properly in the
    # original layer-n, but we already warn the users about this
    # (ctrl-f "shared between different layers/models").
    layer._track_trackable(value, name, overwrite=True)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is not None:
    # Don't overwrite already defined attributes.
    pass
  else:
    setattr(layer, name, value)
1067
1068
class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata.

    Args:
      metadata: Decoded metadata dict with `name`, `dtype`, `sparse`,
        `ragged`, `batch_input_shape` and `config` entries.

    Returns:
      A `(revived_input_layer, setattr)` tuple; the builtin `setattr` is the
      setter used for this object's attributes.
    """
    constructor_keys = ('name', 'dtype', 'sparse', 'ragged',
                        'batch_input_shape')
    revived_obj = cls(**{key: metadata[key] for key in constructor_keys})
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access
    return revived_obj, setattr

  def get_config(self):
    """Returns the config stored in the SavedModel metadata."""
    return self._config
1089
1090
def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserialize Keras object from a nested structure.

  Dicts with a `'class_name'` key are handed to
  `generic_utils.deserialize_keras_object`. Other dicts are processed value by
  value, and lists/tuples element by element; tuples come back as lists.

  Args:
    config: Nested structure of dicts, lists and tuples.
    module_objects: Optional name -> class mapping forwarded to
      `deserialize_keras_object` (e.g. `{'InputSpec': InputSpec}`).

  Returns:
    The structure with every Keras object config replaced by the object.

  Raises:
    ValueError: If a leaf of `config` is not a dict, list or tuple.
  """
  if isinstance(config, (tuple, list)):
    return [recursively_deserialize_keras_object(item, module_objects)
            for item in config]
  if not isinstance(config, dict):
    raise ValueError('Unable to decode config: {}'.format(config))
  if 'class_name' in config:
    return generic_utils.deserialize_keras_object(
        config, module_objects=module_objects)
  return {key: recursively_deserialize_keras_object(value, module_objects)
          for key, value in config.items()}
1106
1107
def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Dimensions that agree in both shapes are kept. Any dimension that differs,
  or is unknown in either shape, becomes `None`. If the ranks differ or are
  unknown, the result is a fully unknown shape.

  Args:
    x: A `TensorShape`, or None if the associated input was not a Tensor.
    y: A `TensorShape`, or None. Must be None exactly when `x` is None.

  Returns:
    The common `TensorShape`, or None if both `x` and `y` are None.

  Raises:
    RuntimeError: If exactly one of `x` and `y` is None.
    TypeError: If a non-None argument is not a `TensorShape`.
  """
  # Parenthesized on purpose: `x is None != y is None` would be a chained
  # comparison (`x is None and None != y and y is None`) that is never true.
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
1131
1132
def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  When the call function was traced more than once, the input specs of all
  traces are merged leaf by leaf. Each merged spec keeps the first trace's
  spec type and dtype, with a shape compatible with every trace (see
  `get_common_shape`).

  Args:
    fn: Restored layer call function. It is assumed that `fn` has at least
        one concrete function and that the inputs are in the first argument.

  Returns:
    TensorSpec of call function inputs.
  """
  def common_spec(x, y):
    shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensorSpec(shape, x.dtype)
    if isinstance(x, ragged_tensor.RaggedTensorSpec):
      # NOTE(review): only shape and dtype are carried over; other
      # RaggedTensorSpec fields fall back to constructor defaults. Confirm.
      return ragged_tensor.RaggedTensorSpec(shape, x.dtype)
    return tensor_spec.TensorSpec(shape, x.dtype, x.name)

  def first_input(concrete_fn):
    # structured_input_signature is (args, kwargs); inputs are args[0].
    return concrete_fn.structured_input_signature[0][0]

  merged = first_input(fn.concrete_functions[0])
  for concrete_fn in fn.concrete_functions[1:]:
    merged = nest.map_structure(common_spec, merged, first_input(concrete_fn))
  return merged
1156
1157
class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto.

    Only the name is passed to the constructor. `expects_training_arg`, a
    valid `config` and an `activity_regularizer` are applied afterwards.

    Args:
      metadata: Decoded metadata dict with at least `name` and
        `expects_training_arg`.

    Returns:
      A `(revived_network, setter)` tuple, where `setter` is `_revive_setter`.
    """
    revived_obj = cls(name=metadata['name'])

    # These assignments are made inside
    # `no_automatic_dependency_tracking_scope`.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        revived_obj._config = saved_config

      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            saved_regularizer)
      # pylint:enable=protected-access

    return revived_obj, _revive_setter
1182
1183
def _set_network_attributes_from_metadata(revived_obj):
  """Applies the saved dtype policy and `trainable` flag to a revived network.

  Both values are read from `revived_obj._serialized_attributes['metadata']`.
  The dtype policy is only set when the metadata records a dtype.

  Args:
    revived_obj: Revived network (or model rebuilt from config).
  """
  with utils.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    metadata = revived_obj._serialized_attributes['metadata']
    saved_dtype = metadata.get('dtype')
    if saved_dtype is not None:
      revived_obj._set_dtype_policy(saved_dtype)
    revived_obj._trainable = metadata['trainable']
    # pylint:enable=protected-access
1193
1194
def _maybe_add_serialized_attributes(layer, metadata):
  """Gives `layer` a `_serialized_attributes` dict seeded with `metadata`.

  The dict collects attributes revived from SerializedAttributes (see
  `_revive_setter`). A layer that already has the dict is left unchanged.

  Args:
    layer: Layer being revived.
    metadata: Decoded metadata dict, stored under the `'metadata'` key.
  """
  if hasattr(layer, '_serialized_attributes'):
    return
  with utils.no_automatic_dependency_tracking_scope(layer):
    layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access
1202
1203
def _get_keras_attr(layer):
  """Returns the revived `keras_api` object of `layer`, or None if absent."""
  serialized_attrs = getattr(layer, '_serialized_attributes', {})
  return serialized_attrs.get(constants.KERAS_ATTR, None)
1207