# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
16"""A `Network` is way to compose layers: the topological form of a `Model`.
17"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import itertools
import json
import os

import numpy as np
import six
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_inspect


# pylint: disable=g-import-not-at-top
try:
  import h5py
except ImportError:
  h5py = None

try:
  import yaml
except ImportError:
  yaml = None
# pylint: enable=g-import-not-at-top


class Network(base_layer.Layer):
  """A `Network` is a composition of layers.

  `Network` is the topological form of a "model". A `Model`
  is simply a `Network` with added training routines.

  Two types of `Networks` exist: Graph Networks and Subclass Networks. Graph
  networks are used in the Keras Functional and Sequential APIs. Subclassed
  networks are used when a user subclasses the `Model` class. In general,
  more Keras features are supported with Graph Networks than with Subclassed
  Networks, specifically:

  - Model cloning (`keras.models.clone`)
  - Serialization (`model.get_config()/from_config`, `model.to_json()/to_yaml()`)
  - Whole-model saving (`model.save()`)

  A Graph Network can be instantiated by passing two arguments to `__init__`.
  The first argument is the `keras.Input` Tensors that represent the inputs
  to the Network. The second argument specifies the output Tensors that
  represent the outputs of this Network. Both arguments can be a nested
  structure of Tensors.

  Example:

  ```
  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
  outputs = keras.layers.Add()([t, inputs['x2']])
  network = Network(inputs, outputs)
  ```

  A Graph Network constructed using the Functional API can also include raw
  TensorFlow functions, with the exception of functions that create Variables
  or assign ops.

  Example:

  ```
  inputs = keras.Input(shape=(10,))
  x = keras.layers.Dense(1)(inputs)
  outputs = tf.nn.relu(x)
  network = Network(inputs, outputs)
  ```

  Subclassed Networks can be instantiated via `name` and (optional) `dynamic`
  keyword arguments. Subclassed Networks keep track of their Layers, and their
  `call` method can be overridden. Subclassed Networks are typically created
  indirectly, by subclassing the `Model` class.

  Example:

  ```
  class MyModel(keras.Model):
    def __init__(self):
      super(MyModel, self).__init__(name='my_model', dynamic=False)

      self.layer1 = keras.layers.Dense(10, activation='relu')

    def call(self, inputs):
      return self.layer1(inputs)
  ```

  Allowed args in `super().__init__`:
    name: String name of the model.
    dynamic: (Subclassed models only) Set this to `True` if your model should
      only be run eagerly, and should not be used to generate a static
      computation graph. This attribute is automatically set for Functional API
      models.
    trainable: Boolean, whether the model's variables should be trainable.
    dtype: (Subclassed models only) Default dtype of the model's weights (
      default of `None` means use the type of the first input). This attribute
      has no effect on Functional API models, which do not have weights of their
      own.
152  """
153
154  # See tf.Module for the usage of this property.
155  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
156  # flatten the key since it is trying to convert Trackable/Layer to a string.
157  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
158      ('_layer_call_argspecs', '_compiled_trainable_state'),
159      base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES
160  ))
161
162  def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
163    # Signature detection
164    if (len(args) == 2 or
165        len(args) == 1 and 'outputs' in kwargs or
166        'inputs' in kwargs and 'outputs' in kwargs):
167      # Graph network
168      self._init_graph_network(*args, **kwargs)
169    else:
170      # Subclassed network
171      self._init_subclassed_network(**kwargs)
172
173    tf_utils.assert_no_legacy_layers(self.layers)
174
175  # Several Network methods have "no_automatic_dependency_tracking"
176  # annotations. Since Network does automatic dependency tracking on attribute
177  # assignment, including for common data structures such as lists, by default
178  # we'd have quite a few empty dependencies which users don't care about (or
179  # would need some way to ignore dependencies automatically, which is confusing
180  # when applied to user code). Some attributes, such as _layers, would cause
181  # structural issues (_layers being the place where Layers assigned to tracked
182  # attributes are stored).
183  #
184  # Aside from these aesthetic and structural issues, useless dependencies on
185  # empty lists shouldn't cause issues; adding or removing them will not break
186  # checkpoints, but may cause "all Python objects matched" assertions to fail
187  # (in which case less strict assertions may be substituted if necessary).
188  @trackable.no_automatic_dependency_tracking
189  def _base_init(self, name=None, **kwargs):
190    # The following are implemented as property functions:
191    # self.trainable_weights
192    # self.non_trainable_weights
193    # self.input_spec
194    # self.losses
195    # self.updates
196
197    generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic',
198                                           'autocast'})
199
200    super(Network, self).__init__(name=name, **kwargs)
201
202    self._is_compiled = False
203
204    # This is True for Sequential networks and Functional networks.
205    self._compute_output_and_mask_jointly = False
206
207    if not hasattr(self, 'optimizer'):
208      # Don't reset optimizer if already set.
209      self.optimizer = None
210
211    self._scope = None  # Never used.
212    self._reuse = None  # Never used.
213    if context.executing_eagerly():
214      self._graph = None
215    else:
216      self._graph = ops.get_default_graph()  # Used in symbolic mode only.
217
218    self._trackable_saver = (
219        trackable_utils.saver_with_op_caching(self))
220
221  @trackable.no_automatic_dependency_tracking
222  def _init_graph_network(self, inputs, outputs, name=None, **kwargs):
223    generic_utils.validate_kwargs(
224        kwargs, {'trainable'},
225        'Functional models may only specify `name` and `trainable` keyword '
226        'arguments during initialization. Got an unexpected argument:')
227    # Normalize and set self.inputs, self.outputs.
228    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
229      inputs = inputs[0]
230    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
231      outputs = outputs[0]
232    self._nested_outputs = outputs
233    self._nested_inputs = inputs
234    self.inputs = nest.flatten(inputs)
235    self.outputs = nest.flatten(outputs)
236
237    if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
238      base_layer_utils.create_keras_history(self._nested_outputs)
239
240    self._base_init(name=name, **kwargs)
241    self._validate_graph_inputs_and_outputs()
242
243    # A Network does not create weights of its own, thus it is already
244    # built.
245    self.built = True
246    self._compute_output_and_mask_jointly = True
247    self._is_graph_network = True
248    # `_expects_training_arg` is True since the `training` argument is always
249    # present in the signature of the `call` method of a graph network.
250    self._expects_training_arg = True
251    self._expects_mask_arg = True
252    # A graph network does not autocast inputs, as its layers will cast them
253    # instead.
254    self._autocast = False
255
256    self._input_layers = []
257    self._output_layers = []
258    self._input_coordinates = []
259    self._output_coordinates = []
260
261    self._supports_ragged_inputs = None
262
263    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one pass,
    # then cache them here. When any of these outputs is queried later, we
    # retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, _ = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._layers = layers
    self._layer_call_argspecs = {}
    for layer in self._layers:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
      layer._attribute_sentinel.add_parent(self._attribute_sentinel)

    # Create the node linking internal inputs to internal outputs.
    node_module.Node(
        outbound_layer=self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=self._nested_inputs,
        output_tensors=self._nested_outputs)

    # Build self.input_names and self.output_names.
    self._set_output_names()
    self.input_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for layer in self._input_layers:
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        # Use batch_input_shape here because non-eager composite tensors may not
        # have a shape attribute that's meaningful (sparse, for instance, has
        # a tensor that's non-constant and needs to be fed). This means that
        # input layers that create placeholders will need to have the
        # batch_input_shape attr to allow for input shape validation.
        self._feed_input_shapes.append(layer._batch_input_shape)
        self._feed_inputs.append(layer.input)

    self._compute_tensor_usage_count()

  def _set_output_names(self):
    """Assigns unique names to the Network's outputs.

    Output layers with multiple output tensors would otherwise lead to duplicate
    names in self.output_names.
    """
    uniquified = []
    output_names = set()
    prefix_count = {}
    for layer in self._output_layers:
      proposal = layer.name
      while proposal in output_names:
        existing_count = prefix_count.get(layer.name, 1)
        proposal = '{}_{}'.format(layer.name, existing_count)
        prefix_count[layer.name] = existing_count + 1
      output_names.add(proposal)
      uniquified.append(proposal)
    self.output_names = uniquified

  @trackable.no_automatic_dependency_tracking
  def _init_subclassed_network(self, name=None, **kwargs):
    self._base_init(name=name, **kwargs)
    self._is_graph_network = False
    self._init_call_fn_args()
    self._autocast = kwargs.get('autocast',
                                base_layer_utils.v2_dtype_behavior_enabled())
    self._supports_ragged_inputs = None
    self.outputs = []
    self.inputs = []
    self.built = False

  @property
  @trackable_layer_utils.cache_recursive_attribute('dynamic')
  def dynamic(self):
    if self._is_graph_network:
      return any(layer.dynamic for layer in self.layers)
    return self._dynamic or any(layer.dynamic for layer in self.layers)

  @property
  def _layer_checkpoint_dependencies(self):
    """Dictionary of layer dependencies to be included in the checkpoint."""
    # Use getattr because this function can be called from __setattr__, at which
    # point the _is_graph_network attribute has not been created.
    if (not getattr(self, '_is_graph_network', False) and
        base_layer_utils.is_subclassed(self)):
      return {}  # Only add layer dependencies for graph networks

    weight_layer_index = 0

    dependencies = {}
    for layer_index, layer in enumerate(self.layers):
      try:
        if layer.weights:
          # Keep a separate index for layers which have weights. This allows
          # users to insert Layers without weights anywhere in the network
          # without breaking checkpoints.
          dependencies['layer_with_weights-%d' % weight_layer_index] = layer
          weight_layer_index += 1
      except ValueError:
        # The layer might have weights, but may not be built yet. We just treat
        # it as a layer without weights.
        pass

      # Even if it doesn't have weights, we should still track everything in
      # case it has/will have Trackable dependencies.
      dependencies['layer-%d' % layer_index] = layer
    return dependencies

  @property
  def _checkpoint_dependencies(self):
    dependencies = [
        trackable.TrackableReference(name=name, ref=layer)
        for name, layer in self._layer_checkpoint_dependencies.items()]
    dependencies.extend(super(Network, self)._checkpoint_dependencies)
    return dependencies

  def _lookup_dependency(self, name):
    layer_dependencies = self._layer_checkpoint_dependencies
    if name in layer_dependencies:
      return layer_dependencies[name]
    return super(Network, self)._lookup_dependency(name)

  def _handle_deferred_layer_dependencies(self, layers):
    """Handles layer checkpoint dependencies that are added after init."""
    layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
    layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
    for layer in layers:
      if layer in layer_to_name:
        self._handle_deferred_dependencies(name=layer_to_name[layer],
                                           trackable=layer)

  def __setattr__(self, name, value):
    if not getattr(self, '_self_setattr_tracking', True):
      super(Network, self).__setattr__(name, value)
      return

    if all(
        isinstance(v, (base_layer.Layer,
                       data_structures.TrackableDataStructure)) or
        trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
      try:
        self._is_graph_network
      except AttributeError:
        # six.raise_from suppresses the original AttributeError from being raised
        six.raise_from(
            RuntimeError('It looks like you are subclassing `Model` and you '
                         'forgot to call `super(YourClass, self).__init__()`.'
                         ' Always start with this line.'), None)

    super(Network, self).__setattr__(name, value)

    # Keep track of metric instance created in subclassed model/layer.
    # We do this so that we can maintain the correct order of metrics by adding
    # the instance to the `metrics` list as soon as it is created.
    from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
    if isinstance(value, metrics_module.Metric):
      self._metrics.append(value)

  @property
  @trackable_layer_utils.cache_recursive_attribute('stateful')
  def stateful(self):
    return any(getattr(layer, 'stateful', False) for layer in self.layers)

  def reset_states(self):
    for layer in self.layers:
      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
        layer.reset_states()

  @property
  def state_updates(self):
    """Returns the `updates` from all layers that are stateful.

    This is useful for separating training updates and
    state updates, e.g. when we need to update a layer's internal state
    during prediction.

    Returns:
        A list of update ops.
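
    Example (a minimal sketch; assumes graph mode and a model containing a
    stateful layer, such as a stateful RNN):

    ```python
    model = keras.Sequential([
        keras.layers.LSTM(4, stateful=True, batch_input_shape=(1, 3, 2))])
    # Ops that refresh the LSTM's internal state; run them alongside the
    # prediction ops when state should advance during inference.
    updates = model.state_updates
    ```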
466    """
467    state_updates = []
468    for layer in self.layers:
469      if getattr(layer, 'stateful', False):
470        if hasattr(layer, 'updates'):
471          state_updates += layer.updates
472    return state_updates
473
474  @property
475  def weights(self):
476    """Returns the list of all layer variables/weights.
477
478    Returns:
479      A list of variables.
480    """
481    return self._dedup_weights(self._undeduplicated_weights)
482
483  @property
484  def _undeduplicated_weights(self):
485    """Returns the undeduplicated list of all layer variables/weights."""
486    self._assert_weights_created()
487    weights = []
488    for layer in self._layers:
489      weights += layer.weights
490    weights += (self._trainable_weights + self._non_trainable_weights)
491    return weights
492
493  @property
494  @tracking.cached_per_instance
495  def _should_compute_mask(self):
496    return self._is_graph_network and super(Network, self)._should_compute_mask
497
498  def compute_mask(self, inputs, mask):
499    if not self._is_graph_network:
500      return None
501
502    # TODO(omalleyt): b/123540974 This function is not really safe to call
503    # by itself because it will duplicate any updates and losses in graph
504    # mode by `call`ing the Layers again.
505    output_tensors = self._run_internal_graph(inputs, mask=mask)
506    return nest.map_structure(lambda t: t._keras_mask, output_tensors)
507
508  @property
509  def layers(self):
510    return list(
511        trackable_layer_utils.filter_empty_layer_containers(self._layers))
512
513  def get_layer(self, name=None, index=None):
514    """Retrieves a layer based on either its name (unique) or index.
515
516    If `name` and `index` are both provided, `index` will take precedence.
517    Indices are based on order of horizontal graph traversal (bottom-up).
518
519    Arguments:
520        name: String, name of layer.
521        index: Integer, index of layer.
522
523    Returns:
524        A layer instance.
525
526    Raises:
527        ValueError: In case of invalid layer name or index.
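
    Example (illustrative; assumes the model contains a layer named 'dense_1'):

    ```python
    layer = model.get_layer('dense_1')  # retrieval by name
    first_layer = model.get_layer(index=0)  # retrieval by index
    ```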
528    """
529    # TODO(fchollet): We could build a dictionary based on layer names
530    # since they are constant, but we have not done that yet.
531    if index is not None:
532      if len(self.layers) <= index:
533        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
534                         ' but model only has ' + str(len(self.layers)) +
535                         ' layers.')
536      else:
537        return self.layers[index]
538    else:
539      if not name:
540        raise ValueError('Provide either a layer name or layer index.')
541    for layer in self.layers:
542      if layer.name == name:
543        return layer
544    raise ValueError('No such layer: ' + name)
545
546  @property
547  def trainable_weights(self):
548    self._assert_weights_created()
549    return self._dedup_weights(
550        trackable_layer_utils.gather_trainable_weights(
551            trainable=self.trainable,
552            sub_layers=self._layers,
553            extra_variables=self._trainable_weights))
554
555  @property
556  def non_trainable_weights(self):
557    self._assert_weights_created()
558    return self._dedup_weights(
559        trackable_layer_utils.gather_non_trainable_weights(
560            trainable=self.trainable,
561            sub_layers=self._layers,
562            extra_variables=self._non_trainable_weights +
563            self._trainable_weights))
564
565  @property
566  def input_spec(self):
567    """Gets the network's input specs.
568
569    Returns:
570        A list of `InputSpec` instances (one per input to the model)
571            or a single instance if the model has only one input.
572    """
573    # If subclassed model, can't assume anything.
574    if not self._is_graph_network:
575      return None
576
577    specs = []
578    for layer in self._input_layers:
579      if layer.input_spec is None:
580        specs.append(None)
581      else:
582        if not isinstance(layer.input_spec, list):
583          raise TypeError('Layer ' + layer.name +
584                          ' has an input_spec attribute that '
585                          'is not a list. We expect a list. '
586                          'Found input_spec = ' + str(layer.input_spec))
587        specs += layer.input_spec
588    if len(specs) == 1:
589      return specs[0]
590    return specs
591
592  @base_layer_utils.default
593  def build(self, input_shape):
594    """Builds the model based on input shapes received.
595
596    This is to be used for subclassed models, which do not know at instantiation
597    time what their inputs look like.
598
599    This method only exists for users who want to call `model.build()` in a
600    standalone way (as a substitute for calling the model on real data to
601    build it). It will never be called by the framework (and thus it will
602    never throw unexpected errors in an unrelated workflow).
603
604    Args:
605     input_shape: Single tuple, TensorShape, or list of shapes, where shapes
606         are tuples, integers, or TensorShapes.
607
608    Raises:
609      ValueError:
610        1. In case of invalid user-provided data (not of type tuple,
611           list, or TensorShape).
612        2. If the model requires call arguments that are agnostic
613           to the input shapes (positional or kwarg in call signature).
614        3. If not all layers were properly built.
615        4. If float type inputs are not supported within the layers.
616
617      In each of these cases, the user should build their model by calling it
618      on real tensor data.
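
    Example (a minimal sketch, reusing the `MyModel` subclass from the class
    docstring above):

    ```python
    model = MyModel()
    # Create the weights without calling the model on real data; the
    # leading `None` is the (unspecified) batch dimension.
    model.build((None, 10))
    model.summary()
    ```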
619    """
620    if self._is_graph_network:
621      self.built = True
622      return
623
624    # If subclass network
625    if input_shape is None:
626      raise ValueError('Input shape must be defined when calling build on a '
627                       'model subclass network.')
628    valid_types = (tuple, list, tensor_shape.TensorShape)
629    if not isinstance(input_shape, valid_types):
630      raise ValueError('Specified input shape is not one of the valid types. '
631                       'Please specify a batch input shape of type tuple or '
632                       'list of input shapes. User provided '
633                       'input type: {}'.format(type(input_shape)))
634
635    if input_shape and not self.inputs:
636      # We create placeholders for the `None`s in the shape and build the model
637      # in a Graph. Since tf.Variable is compatible with both eager execution
638      # and graph building, the variables created after building the model in
639      # a Graph are still valid when executing eagerly.
640      if context.executing_eagerly():
641        graph = func_graph.FuncGraph('build_graph')
642      else:
643        graph = backend.get_graph()
644      with graph.as_default():
645        if isinstance(input_shape, list):
646          x = [base_layer_utils.generate_placeholders_from_shape(shape)
647               for shape in input_shape]
648        else:
649          x = base_layer_utils.generate_placeholders_from_shape(input_shape)
650
651        kwargs = {}
652        call_signature = self._call_full_argspec
653        call_args = call_signature.args
654        # Exclude `self`, `inputs`, and any argument with a default value.
655        if len(call_args) > 2:
656          if call_signature.defaults:
657            call_args = call_args[2:-len(call_signature.defaults)]
658          else:
659            call_args = call_args[2:]
660          for arg in call_args:
661            if arg == 'training':
662              # Case where `training` is a positional arg with no default.
663              kwargs['training'] = False
664            else:
665              # Has invalid call signature with unknown positional arguments.
666              raise ValueError(
667                  'Currently, you cannot build your model if it has '
668                  'positional or keyword arguments that are not '
669                  'inputs to the model, but are required for its '
670                  '`call` method. Instead, in order to instantiate '
671                  'and build your model, `call` your model on real '
672                  'tensor data with all expected call arguments.')
673        elif len(call_args) < 2:
674          # Signature without `inputs`.
675          raise ValueError('You can only call `build` on a model if its `call` '
676                           'method accepts an `inputs` argument.')
677        try:
678          self.call(x, **kwargs)
679        except (errors.InvalidArgumentError, TypeError):
680          raise ValueError('You cannot build your model by calling `build` '
681                           'if your layers do not support float type inputs. '
682                           'Instead, in order to instantiate and build your '
683                           'model, `call` your model on real tensor data (of '
684                           'the correct dtype).')
685
686    self.built = True
687
688  def call(self, inputs, training=None, mask=None):
689    """Calls the model on new inputs.
690
691    In this case `call` just reapplies
692    all ops in the graph to the new inputs
693    (e.g. build a new computational graph from the provided inputs).
694
695    Arguments:
696        inputs: A tensor or list of tensors.
697        training: Boolean or boolean scalar tensor, indicating whether to run
698          the `Network` in training mode or inference mode.
699        mask: A mask or list of masks. A mask can be
700            either a tensor or None (no mask).
701
702    Returns:
703        A tensor if there is a single output, or
704        a list of tensors if there are more than one outputs.
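
    Example (a minimal sketch of reapplying a functional network to a fresh
    symbolic input):

    ```python
    inputs = keras.Input(shape=(10,))
    outputs = keras.layers.Dense(1)(inputs)
    network = Network(inputs, outputs)

    new_inputs = keras.Input(shape=(10,))
    # Calling the network reapplies its graph of ops to the new input.
    new_outputs = network(new_inputs)
    ```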
705    """
706    if not self._is_graph_network:
707      raise NotImplementedError('When subclassing the `Model` class, you should'
708                                ' implement a `call` method.')
709
710    return self._run_internal_graph(
711        inputs, training=training, mask=mask,
712        convert_kwargs_to_constants=base_layer_utils.call_context().saving)
713
714  def compute_output_shape(self, input_shape):
715    if not self._is_graph_network:
716      return super(Network, self).compute_output_shape(input_shape)
717
718    # Convert any shapes in tuple format to TensorShapes.
719    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
720
721    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
722      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
723                       ': model has ' + str(len(self._input_layers)) +
724                       ' tensor inputs.')
725
726    cache_key = generic_utils.object_list_uid(input_shape)
727    if cache_key in self._output_shape_cache:
728      # Cache hit. Return shapes as TensorShapes.
729      return self._output_shape_cache[cache_key]
730
731    layers_to_output_shapes = {}
732    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
733      # It's an input layer: then `compute_output_shape` is identity,
      # and there is only one node and one tensor.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          # This is always a single layer, never a list.
          layer = node.outbound_layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Potentially redundant list,
          # same size as node.input_tensors.
          layer_input_shapes = []
          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
            input_layer_key = inbound_layer.name + '_%s_%s' % (node_id,
                                                               tensor_id)
            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
          layer_input_shapes = nest.pack_sequence_as(node.inbound_layers,
                                                     layer_input_shapes)
          # Layers expect shapes to be tuples for `compute_output_shape`.
          layer_input_shapes = tf_utils.convert_shapes(
              layer_input_shapes, to_tuples=True)
          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
          # Convert back to TensorShapes.
          layer_output_shapes = tf_utils.convert_shapes(
              layer_output_shapes, to_tuples=False)

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = shape

      # Read final output shapes from layers_to_output_shapes.
      output_shapes = []
      for i in range(len(self._output_layers)):
        layer, node_index, tensor_index = self._output_coordinates[i]
        shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
        output_shapes.append(layers_to_output_shapes[shape_key])
      output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
      # Store in cache.
      self._output_shape_cache[cache_key] = output_shapes

    # Return shapes as TensorShapes.
    return output_shapes

  def _run_internal_graph(self, inputs, training=None, mask=None,
                          convert_kwargs_to_constants=False):
    """Computes output tensors for new inputs.

    # Note:
        - Can be run on non-Keras tensors.

    Arguments:
        inputs: Tensor or nested structure of Tensors.
        training: Boolean learning phase.
        mask: (Optional) Tensor or nested structure of Tensors.
        convert_kwargs_to_constants: Whether to convert Tensor kwargs to
          constants. This is used when tracing the model call function during
          saving to ensure that external tensors aren't captured.

    Returns:
        A nested structure of output tensors.
803    """
804    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having to fully reimplement the network
    # calling logic on the Keras side. We choose to incorporate it in
    # Network because 1) it may be useful to fully support in tf.layers in
    # the future and 2) Keras is a major user of Network.  If you don't
    # use masking, it does not interfere with regular behavior at all and you
    # can ignore it.

    if isinstance(inputs, dict) and isinstance(self._nested_inputs,
                                               (list, tuple)):
      # Backwards compat: Allows passing a dict to a Model constructed with a
      # list. Matches dict keys to input names.
      inputs = [
          inputs[inp._keras_history.layer.name] for inp in self._nested_inputs
      ]
    else:
      inputs = nest.flatten(inputs)

    if mask is None:
      masks = [None for _ in range(len(inputs))]
    else:
      masks = nest.flatten(mask)

    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
    tensor_dict = {}

    for x, y in zip(self.inputs, inputs):
      x_id = str(id(x))
      tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
      if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
        try:
          y.set_shape(y.shape.merge_with(x.shape))
        except ValueError:
          logging.warning(
              'Model was constructed with shape {} for input {}, but it was '
              're-called on a Tensor with incompatible shape {}.'
              .format(x.shape, x, y.shape))

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Ignore the InputLayers when computing the graph.
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      nodes = self._nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        if all(
            str(id(tensor)) in tensor_dict
            for tensor in nest.flatten(node.input_tensors)):

          # Call layer (reapplying ops to new inputs).
          computed_tensors = nest.map_structure(
              lambda t: tensor_dict[str(id(t))].pop(), node.input_tensors)

          # Ensure `training` arg propagation if applicable.
          kwargs = copy.copy(node.arguments) if node.arguments else {}
          if convert_kwargs_to_constants:
            kwargs = _map_tensors_to_constants(kwargs)

          argspec = self._layer_call_argspecs[layer].args
          if 'training' in argspec:
            kwargs.setdefault('training', training)
            if (type(kwargs['training']) is ops.Tensor and  # pylint: disable=unidiomatic-typecheck
                any([kwargs['training'] is x
                     for x in backend._GRAPH_LEARNING_PHASES.values()])):
              kwargs['training'] = training  # Materialize placeholder.

          # Map Keras tensors in kwargs to their computed value.
          def _map_tensor_if_from_keras_layer(t):
            if isinstance(t, ops.Tensor) and hasattr(t, '_keras_history'):
              t_id = str(id(t))
              return tensor_dict[t_id].pop()
            return t

          kwargs = nest.map_structure(_map_tensor_if_from_keras_layer, kwargs)

          # Compute outputs.
          output_tensors = layer(computed_tensors, **kwargs)

          # Update tensor_dict.
          for x, y in zip(
              nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
            x_id = str(id(x))
            tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]

    output_tensors = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
      tensor = tensor_dict[str(id(x))].pop()
      output_shapes.append(x.shape)
      output_tensors.append(tensor)

    if output_shapes is not None:
      input_shapes = [x.shape for x in inputs]
      cache_key = generic_utils.object_list_uid(input_shapes)
      self._output_shape_cache[cache_key] = nest.pack_sequence_as(
          self._nested_outputs, output_shapes)

    output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
    return output_tensors

  def get_config(self):
    if not self._is_graph_network:
      raise NotImplementedError
    return copy.deepcopy(get_network_config(self))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Arguments:
        config: Model config dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A model instance.

    Raises:
        ValueError: In case of improperly formatted config dict.
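
    Example (a minimal sketch of a config round trip for a functional model;
    note that only the architecture is reproduced, not the weights):

    ```python
    config = model.get_config()
    new_model = keras.Model.from_config(config)
    ```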
932    """
933    input_tensors, output_tensors, created_layers = reconstruct_from_config(
934        config, custom_objects)
935    model = cls(inputs=input_tensors, outputs=output_tensors,
936                name=config.get('name'))
937    connect_ancillary_layers(model, created_layers)
938    return model
939
940  def save(self,
941           filepath,
942           overwrite=True,
943           include_optimizer=True,
944           save_format=None,
945           signatures=None,
946           options=None):
947    """Saves the model to Tensorflow SavedModel or a single HDF5 file.
948
949    The savefile includes:
950        - The model architecture, allowing to re-instantiate the model.
951        - The model weights.
952        - The state of the optimizer, allowing to resume training
953            exactly where you left off.
954
955    This allows you to save the entirety of the state of a model
956    in a single file.
957
958    Saved models can be reinstantiated via `keras.models.load_model`.
959    The model returned by `load_model` is a compiled model ready to be used
960    (unless the saved model was never compiled in the first place).
961
962    Models built with the Sequential and Functional API can be saved to both the
963    HDF5 and SavedModel formats. Subclassed models can only be saved with the
964    SavedModel format.
965
966    Note that the model weights may have different scoped names after being
967    loaded. Scoped names include the model/layer names, such as
968    "dense_1/kernel:0"`. It is recommended that you use the layer properties to
969     access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    Arguments:
        filepath: String, path to SavedModel or H5 file to save the model.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        include_optimizer: If True, save optimizer's state together.
        save_format: Either 'tf' or 'h5', indicating whether to save the model
            to TensorFlow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and
            'h5' in TF 1.X.
        signatures: Signatures to save with the SavedModel. Applicable to the
            'tf' format only. Please see the `signatures` argument in
            `tf.saved_model.save` for details.
        options: Optional `tf.saved_model.SaveOptions` object that specifies
            options for saving to SavedModel.

    Example:

    ```python
    from keras.models import load_model

    model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
    del model  # deletes the existing model

    # returns a compiled model
    # identical to the previous one
    model = load_model('my_model.h5')
    ```
    """
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
                    signatures, options)

  def save_weights(self, filepath, overwrite=True, save_format=None):
    """Saves all layer weights.

    Either saves in HDF5 or in TensorFlow format based on the `save_format`
    argument.

    When saving in HDF5 format, the weight file has:
      - `layer_names` (attribute), a list of strings
          (ordered names of model layers).
      - For every layer, a `group` named `layer.name`
          - For every such layer group, a group attribute `weight_names`,
              a list of strings
              (ordered names of weights tensor of the layer).
          - For every weight in the layer, a dataset
              storing the weight value, named after the weight tensor.

    When saving in TensorFlow format, all objects referenced by the network are
    saved in the same format as `tf.train.Checkpoint`, including any `Layer`
    instances or `Optimizer` instances assigned to object attributes. For
    networks constructed from inputs and outputs using `tf.keras.Model(inputs,
    outputs)`, `Layer` instances used by the network are tracked/saved
    automatically. For user-defined classes which inherit from `tf.keras.Model`,
    `Layer` instances must be assigned to object attributes, typically in the
    constructor. See the documentation of `tf.train.Checkpoint` and
    `tf.keras.Model` for details.

    While the formats are the same, do not mix `save_weights` and
    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
    loaded using `Model.load_weights`. Checkpoints saved using
    `tf.train.Checkpoint.save` should be restored using the corresponding
    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
    `save_weights` for training checkpoints.

    The TensorFlow format matches objects and variables by starting at a root
    object, `self` for `save_weights`, and greedily matching attribute
    names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this
    is the `Checkpoint` even if the `Checkpoint` has a model attached. This
    means saving a `tf.keras.Model` using `save_weights` and loading into a
    `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match
    the `Model`'s variables. See the [guide to training
    checkpoints](https://www.tensorflow.org/guide/checkpoint) for details
    on the TensorFlow format.

    Arguments:
        filepath: String, path to the file to save the weights to. When saving
            in TensorFlow format, this is the prefix used for checkpoint files
            (multiple files are generated). Note that the '.h5' suffix causes
            weights to be saved in HDF5 format.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
            '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
            `None` defaults to 'tf'.

    Raises:
        ImportError: If h5py is not available when attempting to save in HDF5
            format.
        ValueError: For invalid/unknown format arguments.
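
    Example (a minimal sketch of a TensorFlow-format round trip; './ckpt' is
    an illustrative checkpoint prefix and `make_model` a hypothetical helper
    that rebuilds the same architecture):

    ```python
    model.save_weights('./ckpt')  # writes ckpt.index, ckpt.data-..., etc.
    new_model = make_model()  # hypothetical: same architecture as `model`
    new_model.load_weights('./ckpt')
    ```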
1059    """
1060    self._assert_weights_created()
1061    filepath_is_h5 = _is_hdf5_filepath(filepath)
1062    if save_format is None:
1063      if filepath_is_h5:
1064        save_format = 'h5'
1065      else:
1066        save_format = 'tf'
1067    else:
1068      user_format = save_format.lower().strip()
1069      if user_format in ('tensorflow', 'tf'):
1070        save_format = 'tf'
1071      elif user_format in ('hdf5', 'h5', 'keras'):
1072        save_format = 'h5'
1073      else:
1074        raise ValueError(
1075            'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
1076                save_format,))
1077    if save_format == 'tf' and filepath_is_h5:
1078      raise ValueError(
1079          ('save_weights got save_format="tf"/"tensorflow", but the '
1080           'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
1081           'when saving in TensorFlow format.')
1082          % filepath)
1083
1084    if save_format == 'h5' and h5py is None:
1085      raise ImportError(
1086          '`save_weights` requires h5py when saving in hdf5.')
1087    if save_format == 'tf':
1088      check_filepath = filepath + '.index'
1089    else:
1090      check_filepath = filepath
1091    # If file exists and should not be overwritten:
1092    if not overwrite and os.path.isfile(check_filepath):
1093      proceed = ask_to_proceed_with_overwrite(check_filepath)
1094      if not proceed:
1095        return
1096    if save_format == 'h5':
1097      with h5py.File(filepath, 'w') as f:
1098        hdf5_format.save_weights_to_hdf5_group(f, self.layers)
1099    else:
1100      if context.executing_eagerly():
1101        session = None
1102      else:
1103        session = backend.get_session()
1104      optimizer = getattr(self, 'optimizer', None)
1105      if (optimizer
1106          and not isinstance(optimizer, trackable.Trackable)):
1107        logging.warning(
1108            ('This model was compiled with a Keras optimizer (%s) but is being '
1109             'saved in TensorFlow format with `save_weights`. The model\'s '
1110             'weights will be saved, but unlike with TensorFlow optimizers in '
1111             'the TensorFlow format the optimizer\'s state will not be '
1112             'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
1113            % (optimizer,))
1114      self._trackable_saver.save(filepath, session=session)
1115      # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
1116      checkpoint_management.update_checkpoint_state_internal(
1117          save_dir=os.path.dirname(filepath),
1118          model_checkpoint_path=filepath,
1119          save_relative_paths=True,
1120          all_model_checkpoint_paths=[filepath])
1121
1122  def load_weights(self, filepath, by_name=False, skip_mismatch=False):
1123    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
1124
    If `by_name` is False, weights are loaded based on the network's
    topology. This means the architecture should be the same as when the weights
    were saved.  Note that layers that don't have weights are not taken into
    account in the topological ordering, so adding or removing layers is fine as
    long as they don't have weights.

    If `by_name` is True, weights are loaded into layers only if they share the
    same name. This is useful for fine-tuning or transfer-learning models where
    some of the layers have changed.

    Only topological loading (`by_name=False`) is supported when loading weights
    from the TensorFlow format. Note that topological loading differs slightly
    between TensorFlow and HDF5 formats for user-defined classes inheriting from
    `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
    TensorFlow format loads based on the object-local names of attributes to
    which layers are assigned in the `Model`'s constructor.

    Arguments:
        filepath: String, path to the weights file to load. For weight files in
            TensorFlow format, this is the file prefix (the same as was passed
            to `save_weights`).
        by_name: Boolean, whether to load weights by name or by topological
            order. Only topological loading is supported for weight files in
            TensorFlow format.
        skip_mismatch: Boolean, whether to skip loading of layers where there is
            a mismatch in the number of weights, or a mismatch in the shape of
            the weight (only valid when `by_name=True`).

    Returns:
        When loading a weight file in TensorFlow format, returns the same status
        object as `tf.train.Checkpoint.restore`. When graph building, restore
        ops are run automatically as soon as the network is built (on first call
        for user-defined classes inheriting from `Model`, immediately if it is
        already built).

        When loading weights in HDF5 format, returns `None`.

    Raises:
        ImportError: If h5py is not available and the weight file is in HDF5
            format.
        ValueError: If `skip_mismatch` is set to `True` when `by_name` is
          `False`.
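
    Example (a minimal sketch of by-name loading for transfer learning;
    'pretrained.h5' is an illustrative path):

    ```python
    # Only layers whose names match entries in the file are restored; with
    # skip_mismatch=True, shape mismatches are skipped instead of raising.
    model.load_weights('pretrained.h5', by_name=True, skip_mismatch=True)
    ```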
1167    """
1168
1169    if skip_mismatch and not by_name:
1170      raise ValueError(
1171          'When calling model.load_weights, skip_mismatch can only be set to '
1172          'True when by_name is True.')
1173
1174    if _is_hdf5_filepath(filepath):
1175      save_format = 'h5'
1176    else:
1177      try:
1178        py_checkpoint_reader.NewCheckpointReader(filepath)
1179        save_format = 'tf'
1180      except errors_impl.DataLossError:
1181        # The checkpoint is not readable in TensorFlow format. Try HDF5.
1182        save_format = 'h5'
1183    if save_format == 'tf':
1184      status = self._trackable_saver.restore(filepath)
1185      if by_name:
1186        raise NotImplementedError(
1187            'Weights may only be loaded based on topology into Models when '
1188            'loading TensorFlow-formatted weights (got by_name=True to '
1189            'load_weights).')
1190      if not context.executing_eagerly():
1191        session = backend.get_session()
1192        # Restore existing variables (if any) immediately, and set up a
1193        # streaming restore for any variables created in the future.
1194        trackable_utils.streaming_restore(status=status, session=session)
1195      status.assert_nontrivial_match()
1196      return status
1197    if h5py is None:
1198      raise ImportError(
1199          '`load_weights` requires h5py when loading weights from HDF5.')
    if not self._is_graph_network and not self.built:
      raise NotImplementedError(
          'Unable to load weights saved in HDF5 format into a subclassed '
          'Model which has not created its variables yet. Call the Model '
          'first, then load the weights.')
    self._assert_weights_created()
    with h5py.File(filepath, 'r') as f:
      if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']
      if by_name:
        hdf5_format.load_weights_from_hdf5_group_by_name(
            f, self.layers, skip_mismatch=skip_mismatch)
      else:
        hdf5_format.load_weights_from_hdf5_group(f, self.layers)

  def _updated_config(self):
    """Util shared between different serialization methods.

    Returns:
        Model config with Keras version information added.
    """
    from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

    config = self.get_config()
    model_config = {
        'class_name': self.__class__.__name__,
        'config': config,
        'keras_version': keras_version,
        'backend': backend.backend()
    }
    return model_config

  def to_json(self, **kwargs):
    """Returns a JSON string containing the network configuration.

    To load a network from a JSON save file, use
    `keras.models.model_from_json(json_string, custom_objects={})`.

    Arguments:
        **kwargs: Additional keyword arguments
            to be passed to `json.dumps()`.

    Returns:
        A JSON string.
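
    Example (a minimal sketch of a JSON round trip; only the architecture is
    saved, not the weights):

    ```python
    json_string = model.to_json()
    new_model = keras.models.model_from_json(json_string)
    ```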
1244    """
1245    model_config = self._updated_config()
1246    return json.dumps(
1247        model_config, default=serialization.get_json_type, **kwargs)
1248
1249  def to_yaml(self, **kwargs):
1250    """Returns a yaml string containing the network configuration.
1251
1252    To load a network from a yaml save file, use
1253    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
1254
1255    `custom_objects` should be a dictionary mapping
1256    the names of custom losses / layers / etc to the corresponding
1257    functions / classes.
1258
1259    Arguments:
1260        **kwargs: Additional keyword arguments
1261            to be passed to `yaml.dump()`.
1262
1263    Returns:
1264        A YAML string.
1265
1266    Raises:
1267        ImportError: if yaml module is not found.
1268    """
1269    if yaml is None:
1270      raise ImportError(
1271          'Requires yaml module installed (`pip install pyyaml`).')
1272    return yaml.dump(self._updated_config(), **kwargs)
1273
1274  def summary(self, line_length=None, positions=None, print_fn=None):
1275    """Prints a string summary of the network.
1276
1277    Arguments:
1278        line_length: Total length of printed lines
1279            (e.g. set this to adapt the display to different
1280            terminal window sizes).
1281        positions: Relative or absolute positions of log elements
1282            in each line. If not provided,
1283            defaults to `[.33, .55, .67, 1.]`.
1284        print_fn: Print function to use. Defaults to `print`.
1285            It will be called on each line of the summary.
1286            You can set it to a custom function
1287            in order to capture the string summary.
1288
1289    Raises:
1290        ValueError: if `summary()` is called before the model is built.
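
    Example (a minimal sketch of capturing the summary as a string via
    `print_fn`):

    ```python
    lines = []
    model.summary(print_fn=lines.append)
    summary_string = '\n'.join(lines)
    ```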
1291    """
1292    if not self.built:
1293      raise ValueError('This model has not yet been built. '
1294                       'Build the model first by calling `build()` or calling '
1295                       '`fit()` with some data, or specify '
1296                       'an `input_shape` argument in the first layer(s) for '
1297                       'automatic build.')
1298    layer_utils.print_summary(self,
1299                              line_length=line_length,
1300                              positions=positions,
1301                              print_fn=print_fn)
1302
1303  def _validate_graph_inputs_and_outputs(self):
1304    """Validates the inputs and outputs of a Graph Network."""
1305    # Check for redundancy in inputs.
1306    if len({id(i) for i in self.inputs}) != len(self.inputs):
1307      raise ValueError('The list of inputs passed to the model '
1308                       'is redundant. '
1309                       'All inputs should only appear once.'
1310                       ' Found: ' + str(self.inputs))
1311
1312    for x in self.inputs:
1313      # Check that x has appropriate `_keras_history` metadata.
1314      if not hasattr(x, '_keras_history'):
1315        cls_name = self.__class__.__name__
1316        raise ValueError('Input tensors to a ' + cls_name + ' ' +
1317                         'must come from `tf.keras.Input`. '
1318                         'Received: ' + str(x) +
1319                         ' (missing previous layer metadata).')
1320      # Check that x is an input tensor.
1321      # pylint: disable=protected-access
1322      layer = x._keras_history.layer
1323      if len(layer._inbound_nodes) > 1 or (
1324          layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
1325        cls_name = self.__class__.__name__
1326        logging.warning(cls_name + ' inputs must come from '
1327                        '`tf.keras.Input` (thus holding past layer metadata), '
1328                        'they cannot be the output of '
1329                        'a previous non-Input layer. '
1330                        'Here, a tensor specified as '
1331                        'input to "' + self.name + '" was not an Input tensor, '
1332                        'it was generated by layer ' + layer.name + '.\n'
1333                        'Note that input tensors are '
1334                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
1335                        'The tensor that caused the issue was: ' + str(x.name))
1336      if isinstance(x, ragged_tensor.RaggedTensor):
1337        self._supports_ragged_inputs = True
1338
1339    # Check compatibility of batch sizes of Input Layers.
1340    input_batch_sizes = [
1341        training_utils.get_static_batch_size(x._keras_history.layer)
1342        for x in self.inputs
1343    ]
1344    consistent_batch_size = None
1345    for batch_size in input_batch_sizes:
1346      if batch_size is not None:
1347        if (consistent_batch_size is not None and
1348            batch_size != consistent_batch_size):
1349          raise ValueError('The specified batch sizes of the Input Layers'
1350                           ' are incompatible. Found batch sizes: {}'.format(
1351                               input_batch_sizes))
1352        consistent_batch_size = batch_size
1353
1354    for x in self.outputs:
1355      if not hasattr(x, '_keras_history'):
1356        cls_name = self.__class__.__name__
1357        raise ValueError('Output tensors to a ' + cls_name + ' must be '
1358                         'the output of a TensorFlow `Layer` '
1359                         '(thus holding past layer metadata). Found: ' + str(x))
1360
1361  def _insert_layers(self, layers, relevant_nodes=None):
1362    """Inserts Layers into the Network after Network creation.
1363
1364    This is only valid for Keras Graph Networks.  Layers added via this function
1365    will be included in the `call` computation and `get_config` of this Network.
1366    They will not be added to the Network's outputs.
1367
1368
1369    Arguments:
1370      layers: Arbitrary nested structure of Layers. Layers must be reachable
1371        from one or more of the `keras.Input` Tensors that correspond to this
1372        Network's inputs.
1373      relevant_nodes: Nodes from the Layers that should be considered part of
1374        this Network. If `None`, all Nodes will be considered part of this
1375        Network.
1376
1377    Raises:
1378      ValueError: If the layers depend on `Input`s not found in this Model.
1379    """
1380    layers = nest.flatten(layers)
1381    tf_utils.assert_no_legacy_layers(layers)
1382    node_to_depth = {}
1383    for depth, nodes in self._nodes_by_depth.items():
1384      node_to_depth.update({node: depth for node in nodes})
1385    # The nodes of these Layers that are relevant to this Network. If not
1386    # provided, assume all Nodes are relevant
1387    if not relevant_nodes:
1388      relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers])
1389    network_nodes = set(relevant_nodes + list(node_to_depth.keys()))
1390
1391    def _get_min_depth(node):
1392      """Gets the minimum depth at which node can be computed."""
1393      min_depth = 0
1394      for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True):
1395        inbound_node = layer._inbound_nodes[node_id]
1396        if inbound_node in node_to_depth:
1397          min_depth = min(min_depth, node_to_depth[inbound_node])
1398        elif inbound_node not in network_nodes:
1399          continue
1400        else:
1401          # Previous relevant nodes haven't been processed yet.
1402          return None
1403      # New node is one shallower than its shallowest input.
1404      return min_depth - 1
1405
1406    # Insert nodes into `_nodes_by_depth` and other node attrs.
1407    unprocessed_nodes = copy.copy(relevant_nodes)
1408    i = 0
1409    while unprocessed_nodes:
1410      i += 1
1411      # Do a sanity check. This can occur if `Input`s from outside this Model
1412      # are being relied on.
1413      if i > 10000:
1414        raise ValueError('Layers could not be added due to missing '
1415                         'dependencies.')
1416
1417      node = unprocessed_nodes.pop(0)
1418      depth = _get_min_depth(node)
1419      if depth is None:  # Defer until inbound nodes are processed.
1420        unprocessed_nodes.append(node)
1421        continue
1422      node_key = _make_node_key(node.outbound_layer.name,
1423                                node.outbound_layer._inbound_nodes.index(node))
1424      if node_key not in self._network_nodes:
1425        node_to_depth[node] = depth
1426        self._network_nodes.add(node_key)
1427        self._nodes_by_depth[depth].append(node)
1428
1429    # Insert layers and update other layer attrs.
1430    layer_set = set(self._layers)
1431    deferred_layers = []
1432    for layer in layers:
1433      if layer not in layer_set:
1434        self._layers.append(layer)
1435        deferred_layers.append(layer)
1436        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
1437
1438        # This allows the added layer to broadcast mutations to the current
1439        # layer, which is necessary to ensure cache correctness.
1440        layer._attribute_sentinel.add_parent(self._attribute_sentinel)
1441        layer_set.add(layer)
1442    self._handle_deferred_layer_dependencies(deferred_layers)
1443
1444    self._compute_tensor_usage_count()
1445
1446  def _compute_tensor_usage_count(self):
1447    """Compute the #. of tensor usages for all the output tensors of layers.
1448
1449    The computed tensor usage count is saved as `self._tensor_usage_count`. This
1450    is later used for saving memory in eager computation by releasing
1451    no-longer-needed tensors as early as possible.
1452    """
1453    tensor_usage_count = collections.Counter()
1454    available_tensors = set(str(id(tensor)) for tensor in self.inputs)
1455
1456    depth_keys = list(self._nodes_by_depth.keys())
1457    depth_keys.sort(reverse=True)
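    # Skip the deepest depth: it holds the network's own inputs, which do not
    # consume other tensors.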
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      for node in self._nodes_by_depth[depth]:
        input_tensors = {
            str(id(tensor)) for tensor in nest.flatten(node.input_tensors)
        }
        if input_tensors.issubset(available_tensors):
          kwargs = copy.copy(node.arguments) if node.arguments else {}

          for tensor in nest.flatten(kwargs):
            if isinstance(tensor, ops.Tensor) and hasattr(tensor,
                                                          '_keras_history'):
              tensor_usage_count[str(id(tensor))] += 1

          for tensor in nest.flatten(node.input_tensors):
            tensor_usage_count[str(id(tensor))] += 1

          for output_tensor in nest.flatten(node.output_tensors):
            available_tensors.add(str(id(output_tensor)))

    for tensor in self.outputs:
      tensor_usage_count[str(id(tensor))] += 1

    self._tensor_usage_count = tensor_usage_count

  def _assert_weights_created(self):
    """Asserts that all the weights for the network have been created.

    For a non-dynamic network, the weights must already be created after the
    layer has been called. For a dynamic network, the exact list of weights can
    never be known for certain since it may change at any time during execution.

    We run this check right before accessing weights or getting the Numpy value
    for the current weights. Otherwise, if the layer has never been called,
    the user would just get an empty list, which is misleading.

    Raises:
      ValueError: if the weights of the network have not yet been created.
    """
    if self.dynamic:
      return
    if (not self._is_graph_network and
        'build' in self.__class__.__dict__ and
        not self.built):
      # For any model that has a customized build() method that hasn't been
      # invoked yet; this covers both Sequential and subclassed models.
      raise ValueError('Weights for model %s have not yet been created. '
                       'Weights are created when the Model is first called on '
                       'inputs or `build()` is called with an `input_shape`.' %
                       self.name)

  def _graph_network_add_loss(self, symbolic_loss):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
    # Losses must be keyed on inputs no matter what in order to be supported in
    # DistributionStrategy.
    add_loss_layer = base_layer.AddLoss(
        unconditional=False, dtype=symbolic_loss.dtype)
    add_loss_layer(symbolic_loss)
    new_nodes.extend(add_loss_layer.inbound_nodes)
    new_layers.append(add_loss_layer)
    self._insert_layers(new_layers, new_nodes)

  def _graph_network_add_metric(self, value, aggregation, name):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
    add_metric_layer = base_layer.AddMetric(
        aggregation, name, dtype=value.dtype)
    add_metric_layer(value)
    new_nodes.extend(add_metric_layer.inbound_nodes)
    new_layers.append(add_metric_layer)
    self._insert_layers(new_layers, new_nodes)

  @property
  def _trackable_saved_model_saver(self):
    return network_serialization.NetworkSavedModelSaver(self)


def _is_hdf5_filepath(filepath):
  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
          filepath.endswith('.hdf5'))


def _make_node_key(layer_name, node_index):
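  # For example: _make_node_key('dense', 0) == 'dense_ib-0'.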
  return layer_name + '_ib-' + str(node_index)


def _map_graph_network(inputs, outputs):
  """Validates a network's topology and gathers its layers and nodes.

  Arguments:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
    - nodes: list of Node instances.
    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
    - layers: list of Layer instances.
    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

  Raises:
    ValueError: In case the network is not valid (e.g. disconnected graph).
  """
  # network_nodes: set of nodes included in the graph of layers
  # (not all nodes included in the layers are relevant to the current graph).
  network_nodes = set()  # ids of all nodes relevant to the Network
  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}
  layer_indices = {}  # dict {layer: index in traversal}
  nodes_in_decreasing_depth = []

  def build_map(tensor,
                finished_nodes,
                nodes_in_progress,
                layer,
                node_index,
                tensor_index):
    """Builds a map of the graph of layers.

    This recursively updates the map `layer_indices`,
    the list `nodes_in_decreasing_depth` and the set `network_nodes`.

    Arguments:
        tensor: Some tensor in a graph.
        finished_nodes: Set of nodes whose subgraphs have been traversed
            completely. Useful to prevent duplicated work.
        nodes_in_progress: Set of nodes that are currently active on the
            recursion stack. Useful to detect cycles.
        layer: Layer from which `tensor` comes. If not provided,
            it will be obtained from `tensor._keras_history`.
        node_index: Index of the node from which `tensor` comes.
        tensor_index: Index of `tensor` within the outputs of its node.

    Raises:
        ValueError: if a cycle is detected.
    """
    node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

    # Prevent cycles.
    if node in nodes_in_progress:
      raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                       layer.name + '" is part of a cycle.')

    # Don't repeat work for shared subgraphs.
    if node in finished_nodes:
      return

    node_key = _make_node_key(layer.name, node_index)
    # Update network_nodes.
    network_nodes.add(node_key)

    # Store the traversal order for layer sorting.
    if layer not in layer_indices:
      layer_indices[layer] = len(layer_indices)

    nodes_in_progress.add(node)

    # Propagate to all previous tensors connected to this node.
    for layer, node_index, tensor_index, tensor in node.iterate_inbound(
        include_arguments=True):
      build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
                tensor_index)

    finished_nodes.add(node)
    nodes_in_progress.remove(node)
    nodes_in_decreasing_depth.append(node)

  finished_nodes = set()
  nodes_in_progress = set()
  for x in outputs:
    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
    build_map(x, finished_nodes, nodes_in_progress,
              layer=layer,
              node_index=node_index,
              tensor_index=tensor_index)

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer.
    previous_depth = layers_depths.get(node.outbound_layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.outbound_layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all the nodes it is connected to, plus 1.
    for node_dep in node._get_all_node_dependencies():
      previous_depth = nodes_depths.get(node_dep, 0)
      nodes_depths[node_dep] = max(depth + 1, previous_depth)

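  # For example, in a chain `Input -> Dense_1 -> Dense_2` (hypothetical layer
  # names, for illustration), the Dense_2 node ends up at depth 0, the Dense_1
  # node at depth 1, and the Input node at depth 2: outputs are shallowest,
  # inputs deepest.
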
  # Handle inputs that are not connected to outputs.
  # We do not error out here because the inputs may be used to compute losses
  # and metrics.
  for input_t in inputs:
    input_layer = input_t._keras_history[0]
    if input_layer not in layers_depths:
      layers_depths[input_layer] = 0
      layer_indices[input_layer] = -1
      nodes_depths[input_layer._inbound_nodes[0]] = 0
      network_nodes.add(_make_node_key(input_layer.name, 0))

  # Build a dict {depth: list of nodes with this depth}
  nodes_by_depth = collections.defaultdict(list)
  for node, depth in nodes_depths.items():
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}
  layers_by_depth = collections.defaultdict(list)
  for layer, depth in layers_depths.items():
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers ordered by depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = set()
  for x in inputs:
    computable_tensors.add(id(x))

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.outbound_layer
      if layer:
        for x in nest.flatten(node.input_tensors):
          if id(x) not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in nest.flatten(node.output_tensors):
          computable_tensors.add(id(x))
        layers_with_complete_input.append(layer.name)

  # Ensure name uniqueness, which is crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth


def _map_subgraph_network(inputs, outputs):
  """Returns the nodes and layers in the topology from `inputs` to `outputs`.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple of (List[Node], List[Layer]).
  """
  base_layer_utils.create_keras_history(outputs)
  # Keep only nodes and layers in the topology between inputs and outputs.
  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
  return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers


def _should_skip_first_node(layer):
  """Returns True if the first layer node should not be saved or loaded."""
  # Networks start with a pre-existing node linking their input to output.
  return issubclass(layer.__class__, Network) and layer._is_graph_network


def _serialize_tensors(kwargs):
  """Serializes Tensors passed to `call`."""

  def _serialize_keras_tensor(t):
    """Serializes a single Tensor passed to `call`."""
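    # Keras tensors are stored by reference as `[layer_name, node_index,
    # tensor_index]`; plain arrays and tensors are stored by value.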
    if hasattr(t, '_keras_history'):
      kh = t._keras_history
      return [kh.layer.name, kh.node_index, kh.tensor_index]

    if isinstance(t, np.ndarray):
      return t.tolist()

    if isinstance(t, ops.Tensor):
      return backend.get_value(t).tolist()

    return t

  return nest.map_structure(_serialize_keras_tensor, kwargs)


def _map_tensors_to_constants(kwargs):

  def _map_to_constants(t):
    if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
      return constant_op.constant(backend.get_value(t))
    return t

  return nest.map_structure(_map_to_constants, kwargs)


def _deserialize_keras_tensors(kwargs, layer_map):
  """Deserializes Keras Tensors passed to `call`."""

  def _deserialize_keras_tensor(t):
    """Deserializes a single Keras Tensor passed to `call`."""
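    # Inverts `_serialize_tensors` above: a `[layer_name, node_index,
    # tensor_index]` reference is resolved against `layer_map`.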
    if isinstance(t, tf_utils.ListWrapper):
      t = t.as_list()
      layer_name = t[0]
      node_index = t[1]
      tensor_index = t[2]

      layer = layer_map[layer_name]
      node = layer._inbound_nodes[node_index]
      return nest.flatten(node.output_tensors)[tensor_index]
    return t

  kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
  return nest.map_structure(_deserialize_keras_tensor, kwargs)


def connect_ancillary_layers(model, created_layers):
  """Adds layers that are not connected to the outputs to the model."""
  # Layers not connected to outputs, such as those added in `add_loss`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  if ancillary_layers:
    relevant_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if _should_skip_first_node(layer) else layer.inbound_nodes
        for layer in created_layers.values()
    ])
    model._insert_layers(ancillary_layers, relevant_nodes)
  return model


def reconstruct_from_config(config, custom_objects=None, created_layers=None):
  """Reconstructs graph from config object.

  Args:
    config: Dictionary returned from `Network.get_config()`.
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.
    created_layers: Optional dictionary mapping names to Layer objects. Any
      layer not in this dictionary will be created and added to the dict.
      This function will add new nodes to all layers (excluding InputLayers),
      instead of re-using pre-existing nodes in the layers.

  Returns:
    Tuple of (input tensors, output tensors, dictionary of created layers).
  """
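  # A minimal usage sketch (assuming `config` came from `Network.get_config()`):
  #
  #   input_tensors, output_tensors, created_layers = (
  #       reconstruct_from_config(config))
  #   model = Network(inputs=input_tensors, outputs=output_tensors,
  #                   name=config.get('name'))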
  # Layer instances created during the graph reconstruction process.
  created_layers = created_layers or collections.OrderedDict()

  # Maps input data (tuple of inbound layer name, node index) from the config
  # to node indices in the newly generated model. The node indices may be
  # different if the layers have already been called previously.
  node_index_map = {}
  node_count_by_layer = {}

  # Dictionary mapping layer instances to
  # node data that specifies a layer call.
  # It acts as a queue that maintains any unprocessed
  # layer call until it becomes possible to process it
  # (i.e. until the input tensors to the call all exist).
  unprocessed_nodes = {}

  def add_unprocessed_node(layer, node_data):
    if layer not in unprocessed_nodes:
      unprocessed_nodes[layer] = [node_data]
    else:
      unprocessed_nodes[layer].append(node_data)

  def get_node_index(layer, config_node_index):
    """Returns node index in layer (might differ from config_node_index)."""
    if isinstance(layer, input_layer_module.InputLayer):
      return 0
    return node_index_map.get((layer.name, config_node_index), None)

  def process_node(layer, node_data):
    """Deserializes a node.

    Arguments:
        layer: layer instance.
        node_data: Nested structure of `ListWrapper`.

    Raises:
        ValueError: In case of improperly formatted `node_data`.
    """
    input_tensors = []
    for input_data in nest.flatten(node_data):
      input_data = input_data.as_list()
      inbound_layer_name = input_data[0]
      inbound_node_index = input_data[1]
      inbound_tensor_index = input_data[2]
      if len(input_data) == 3:
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        kwargs = _deserialize_keras_tensors(kwargs, created_layers)
      else:
        raise ValueError('Improperly formatted model config.')

      inbound_layer = created_layers[inbound_layer_name]
      inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

      if inbound_node_index is None:
        add_unprocessed_node(layer, node_data)
        return
      inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
      input_tensors.append(
          nest.flatten(inbound_node.output_tensors)[inbound_tensor_index])
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.
    if input_tensors is not None:
      input_tensors = base_layer_utils.unnest_if_single_tensor(input_tensors)
      output_tensors = layer(input_tensors, **kwargs)

      # Update node index map.
      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
      node_count_by_layer[layer] += 1

  def process_layer(layer_data):
    """Deserializes a layer, then calls it on appropriate inputs.

    Arguments:
        layer_data: layer config dict.

    Raises:
        ValueError: In case of improperly formatted `layer_data` dict.
    """
    layer_name = layer_data['name']

    if layer_name in created_layers:
      layer = created_layers[layer_name]
    else:
      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

    node_count_by_layer[layer] = int(_should_skip_first_node(layer))

    # Gather layer inputs and convert to `ListWrapper` objects.
    inbound_nodes_data = layer_data['inbound_nodes']
    inbound_nodes_data = tf_utils.convert_inner_node_data(
        inbound_nodes_data, wrap=True)
    for node_data in inbound_nodes_data:
      # We don't process nodes (i.e. make layer calls)
      # on the fly because the inbound node may not yet exist,
      # in the case of layers shared at different topological depths
      # (e.g. a model such as A(B(A(B(x))))).
      add_unprocessed_node(layer, node_data)

  # First, we create all layers and enqueue nodes to be processed.
  for layer_data in config['layers']:
    process_layer(layer_data)
  # Then we process nodes in order of layer depth.
  # Nodes that cannot yet be processed (if the inbound node
  # does not yet exist) are re-enqueued, and the process
  # is repeated until all nodes are processed.
  while unprocessed_nodes:
    for layer_data in config['layers']:
      layer = created_layers[layer_data['name']]
      if layer in unprocessed_nodes:
        for node_data in unprocessed_nodes.pop(layer):
          process_node(layer, node_data)

  input_tensors = []
  output_tensors = []

  input_layers = tf_utils.convert_inner_node_data(
      config['input_layers'], wrap=True)
  for layer_data in nest.flatten(input_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  output_layers = tf_utils.convert_inner_node_data(
      config['output_layers'], wrap=True)
  for layer_data in nest.flatten(output_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
  return input_tensors, output_tensors, created_layers


def get_network_config(network, serialize_layer_fn=None):
  """Builds the config, which consists of the node graph and serialized layers.

  Args:
    network: A Network object.
    serialize_layer_fn: Function used to serialize layers.

  Returns:
    Config dictionary.
  """
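  # The returned config has roughly this shape (a sketch, not a full schema):
  #
  #   {'name': ...,
  #    'layers': [{'name': ..., 'inbound_nodes': ..., ...}, ...],
  #    'input_layers': ...,
  #    'output_layers': ...}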
  serialize_layer_fn = (
      serialize_layer_fn or generic_utils.serialize_keras_object)
  config = {
      'name': network.name,
  }
  node_conversion_map = {}
  for layer in network.layers:
    kept_nodes = 1 if _should_skip_first_node(layer) else 0
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        node_conversion_map[node_key] = kept_nodes
        kept_nodes += 1
  layer_configs = []
  for layer in network.layers:  # From the earliest layers on.
    filtered_inbound_nodes = []
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        # The node is relevant to the model:
        # add to filtered_inbound_nodes.
        if node.arguments:
          kwargs = _serialize_tensors(node.arguments)
          try:
            json.dumps(kwargs)
          except TypeError:
            logging.warning(
                'Layer ' + layer.name +
                ' was passed non-serializable keyword arguments: ' +
                str(node.arguments) + '. They will not be included '
                'in the serialized model (and thus will be missing '
                'at deserialization time).')
            kwargs = {}
        else:
          kwargs = {}
        if node.inbound_layers:
          node_data = []
          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
            node_key = _make_node_key(inbound_layer.name, node_id)
            new_node_index = node_conversion_map.get(node_key, 0)
            node_data.append(
                tf_utils.ListWrapper(
                    [inbound_layer.name, new_node_index, tensor_id, kwargs]))
          node_data = nest.pack_sequence_as(node.input_tensors, node_data)
          if not nest.is_sequence(node_data):
            node_data = [node_data]
          # Convert ListWrapper to list for backwards compatible configs.
          node_data = tf_utils.convert_inner_node_data(node_data)
          filtered_inbound_nodes.append(node_data)

    layer_config = serialize_layer_fn(layer)
    layer_config['name'] = layer.name
    layer_config['inbound_nodes'] = filtered_inbound_nodes
    layer_configs.append(layer_config)
  config['layers'] = layer_configs

  # Gather info about inputs and outputs.
  model_inputs = []
  for i in range(len(network._input_layers)):
    layer, node_index, tensor_index = network._input_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_inputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
  # Preserve external Keras compat for Models with single input.
  if not nest.is_sequence(model_inputs):
    model_inputs = [model_inputs]
  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
  config['input_layers'] = model_inputs

  model_outputs = []
  for i in range(len(network._output_layers)):
    layer, node_index, tensor_index = network._output_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_outputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
  # Preserve external Keras compat for Models with single output.
  if not nest.is_sequence(model_outputs):
    model_outputs = [model_outputs]
  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
  config['output_layers'] = model_outputs
  return config