• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""A `Network` is way to compose layers: the topological form of a `Model`.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import copy
24import itertools
25import warnings
26
27from six.moves import zip  # pylint: disable=redefined-builtin
28
29from tensorflow.python.eager import context
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.keras import backend
33from tensorflow.python.keras.engine import base_layer
34from tensorflow.python.keras.engine import base_layer_utils
35from tensorflow.python.keras.engine import input_layer as input_layer_module
36from tensorflow.python.keras.engine import input_spec
37from tensorflow.python.keras.engine import keras_tensor
38from tensorflow.python.keras.engine import node as node_module
39from tensorflow.python.keras.engine import training as training_lib
40from tensorflow.python.keras.engine import training_utils
41from tensorflow.python.keras.saving.saved_model import network_serialization
42from tensorflow.python.keras.utils import generic_utils
43from tensorflow.python.keras.utils import tf_inspect
44from tensorflow.python.keras.utils import tf_utils
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.training.tracking import base as trackable
49from tensorflow.python.util import nest
50from tensorflow.tools.docs import doc_controls
51
52
53# pylint: disable=g-classes-have-attributes
54class Functional(training_lib.Model):
55  """A `Functional` model is a `Model` defined as a directed graph of layers.
56
57  Three types of `Model` exist: subclassed `Model`, `Functional` model,
58  and `Sequential` (a special case of `Functional`).
59  In general, more Keras features are supported with `Functional`
60  than with subclassed `Model`s, specifically:
61
62  - Model cloning (`keras.models.clone`)
63  - Serialization (`model.get_config()/from_config`, `model.to_json()/to_yaml()`
64  - Whole-model saving (`model.save()`)
65
66  A `Functional` model can be instantiated by passing two arguments to
67  `__init__`. The first argument is the `keras.Input` Tensors that represent
68  the inputs to the model. The second argument specifies the output
69  tensors that represent the outputs of this model. Both arguments can be a
70  nested structure of tensors.
71
72  Example:
73
74  ```
75  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
76  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
77  outputs = keras.layers.Add()([t, inputs['x2'])
78  model = keras.Model(inputs, outputs)
79  ```
80
81  A `Functional` model constructed using the Functional API can also include raw
82  TensorFlow functions, with the exception of functions that create Variables
83  or assign ops.
84
85  Example:
86
87  ```
88  inputs = keras.Input(shape=(10,))
89  x = keras.layers.Dense(1)(inputs)
90  outputs = tf.nn.relu(x)
91  model = keras.Model(inputs, outputs)
92  ```
93
94  Args:
95    inputs: List of input tensors (must be created via `tf.keras.Input()`).
96    outputs: List of output tensors.
97    name: String, optional. Name of the model.
98    trainable: Boolean, optional. If the model's variables should be trainable.
99  """
100
101  # See tf.Module for the usage of this property.
102  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
103  # flatten the key since it is trying to convert Trackable/Layer to a string.
104  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
105      ('_layer_call_argspecs', '_compiled_trainable_state',
106       '_output_mask_cache', '_output_tensor_cache', '_output_shape_cache'),
107      training_lib.Model._TF_MODULE_IGNORED_PROPERTIES
108  ))
109
110  @trackable.no_automatic_dependency_tracking
111  def __init__(self, inputs, outputs, name=None, trainable=True,
112               **kwargs):
113    # This is used by the Model class, since we have some logic to swap the
114    # class in the __new__ method, which will lead to __init__ get invoked
115    # twice. Using the skip_init to skip one of the invocation of __init__ to
116    # avoid any side effects
117    skip_init = kwargs.pop('skip_init', False)
118    if skip_init:
119      return
120    generic_utils.validate_kwargs(kwargs, {})
121    super(Functional, self).__init__(name=name, trainable=trainable)
122    self._init_graph_network(inputs, outputs)
123
124  @trackable.no_automatic_dependency_tracking
125  def _init_graph_network(self, inputs, outputs):
126    base_layer.keras_api_gauge.get_cell('Functional').set(True)
127    # This method is needed for Sequential to reinitialize graph network when
128    # layer is added or removed.
129    self._is_graph_network = True
130
131    # Normalize and set self.inputs, self.outputs.
132    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
133      inputs = inputs[0]
134    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
135      outputs = outputs[0]
136    self._nested_inputs = inputs
137    self._nested_outputs = outputs
138    self.inputs = nest.flatten(inputs)
139    self.outputs = nest.flatten(outputs)
140
141    # Models constructed with a single Tensor or list of Tensors can
142    # be called with a dict, where the keys of the dict are the names
143    # of the `Input` objects. Extra keys are ignored with warning.
144    if not nest.is_nested(self._nested_inputs):
145      self._enable_dict_to_input_mapping = True
146    elif (isinstance(self._nested_inputs, (list, tuple)) and
147          not any(nest.is_nested(t) for t in self._nested_inputs)):
148      self._enable_dict_to_input_mapping = True
149    elif (isinstance(self._nested_inputs, dict) and
150          not any(nest.is_nested(t) for t in self._nested_inputs.values())):
151      self._enable_dict_to_input_mapping = True
152    else:
153      self._enable_dict_to_input_mapping = False
154
155    if not keras_tensor.keras_tensors_enabled():
156      if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
157        base_layer_utils.create_keras_history(self._nested_outputs)
158
159    self._validate_graph_inputs_and_outputs()
160
161    # A Network does not create weights of its own, thus it is already
162    # built.
163    self.built = True
164    self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs)
165    self._compute_output_and_mask_jointly = True
166    # `_expects_training_arg` is True since the `training` argument is always
167    # present in the signature of the `call` method of a graph network.
168    self._expects_training_arg = True
169    self._expects_mask_arg = True
170    # A graph network does not autocast inputs, as its layers will cast them
171    # instead.
172    self._autocast = False
173
174    self._input_layers = []
175    self._output_layers = []
176    self._input_coordinates = []
177    self._output_coordinates = []
178
179    # This is for performance optimization when calling the Network on new
180    # inputs. Every time the Network is called on a set on input tensors,
181    # we compute the output tensors, output masks and output shapes in one pass,
182    # then cache them here. When any of these outputs is queried later, we
183    # retrieve it from there instead of recomputing it.
184    self._output_mask_cache = {}
185    self._output_tensor_cache = {}
186    self._output_shape_cache = {}
187
188    # Build self._output_layers:
189    for x in self.outputs:
190      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
191      self._output_layers.append(layer)
192      self._output_coordinates.append((layer, node_index, tensor_index))
193
194    # Build self._input_layers:
195    for x in self.inputs:
196      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
197      # It's supposed to be an input layer, so only one node
198      # and one tensor output.
199      assert node_index == 0
200      assert tensor_index == 0
201      self._input_layers.append(layer)
202      self._input_coordinates.append((layer, node_index, tensor_index))
203
204    # Keep track of the network's nodes and layers.
205    nodes, nodes_by_depth, layers, _ = _map_graph_network(
206        self.inputs, self.outputs)
207    self._network_nodes = nodes
208    self._nodes_by_depth = nodes_by_depth
209    self._self_tracked_trackables = layers
210    self._layer_call_argspecs = {}
211    for layer in self._self_tracked_trackables:
212      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
213
214    # Build self.input_names and self.output_names.
215    self._set_output_names()
216    self.input_names = []
217    self._feed_input_names = []
218    self._feed_inputs = []
219    self._feed_input_shapes = []
220    for layer in self._input_layers:
221      self.input_names.append(layer.name)
222      if layer.is_placeholder:
223        self._feed_input_names.append(layer.name)
224        # Use batch_input_shape here because non-eager composite tensors may not
225        # have a shape attribute that's meaningful (sparse, for instance, has
226        # a tensor that's non-constant and needs to be fed). This means that
227        # input layers that create placeholders will need to have the
228        # batch_input_shape attr to allow for input shape validation.
229        self._feed_input_shapes.append(layer._batch_input_shape)
230        self._feed_inputs.append(layer.input)
231
232    self._compute_tensor_usage_count()
233    self._set_save_spec(self._nested_inputs)
234    tf_utils.assert_no_legacy_layers(self.layers)
235
236  @property
237  def input(self):
238    """Retrieves the input tensor(s) of a layer.
239
240    Only applicable if the layer has exactly one input,
241    i.e. if it is connected to one incoming layer.
242
243    Returns:
244        Input tensor or list of input tensors.
245
246    Raises:
247      RuntimeError: If called in Eager mode.
248      AttributeError: If no inbound nodes are found.
249    """
250    return self._nested_inputs
251
252  @property
253  def input_shape(self):
254    """Retrieves the input shape(s) of a layer.
255
256    Only applicable if the layer has exactly one input,
257    i.e. if it is connected to one incoming layer, or if all inputs
258    have the same shape.
259
260    Returns:
261        Input shape, as an integer shape tuple
262        (or list of shape tuples, one tuple per input tensor).
263
264    Raises:
265        AttributeError: if the layer has no defined input_shape.
266        RuntimeError: if called in Eager mode.
267    """
268    return nest.map_structure(backend.int_shape, self.input)
269
270  @property
271  def input_spec(self):
272    if hasattr(self, '_manual_input_spec'):
273      return self._manual_input_spec
274    if (isinstance(self._nested_inputs, (dict, list, tuple)) and
275        len(self._nested_inputs) != len(self.inputs)):
276      # Case where we have a nested structure.
277      # In such a case we can't safely run any checks.
278      return None
279    if isinstance(self._nested_inputs, dict):
280      # Case where `_nested_inputs` is a plain dict of Inputs.
281      names = sorted(self._nested_inputs.keys())
282      return [input_spec.InputSpec(
283          shape=shape_with_no_batch_size(self._nested_inputs[name]),
284          allow_last_axis_squeeze=True, name=name) for name in names]
285    else:
286      # Single input, or list / tuple of inputs.
287      # The data may be passed as a dict keyed by input name.
288      return [input_spec.InputSpec(
289          shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True,
290          name=x._keras_history.layer.name) for x in self.inputs]
291
292  @input_spec.setter
293  def input_spec(self, value):
294    self._manual_input_spec = value
295
296  @property
297  def output(self):
298    """Retrieves the output tensor(s) of a layer.
299
300    Only applicable if the layer has exactly one output,
301    i.e. if it is connected to one incoming layer.
302
303    Returns:
304      Output tensor or list of output tensors.
305
306    Raises:
307      AttributeError: if the layer is connected to more than one incoming
308        layers.
309      RuntimeError: if called in Eager mode.
310    """
311    return self._nested_outputs
312
313  @property
314  def output_shape(self):
315    """Retrieves the output shape(s) of a layer.
316
317    Only applicable if the layer has one output,
318    or if all outputs have the same shape.
319
320    Returns:
321        Output shape, as an integer shape tuple
322        (or list of shape tuples, one tuple per output tensor).
323
324    Raises:
325        AttributeError: if the layer has no defined output shape.
326        RuntimeError: if called in Eager mode.
327    """
328    return nest.map_structure(backend.int_shape, self.output)
329
330  def _set_output_names(self):
331    """Assigns unique names to the Network's outputs.
332
333    Output layers with multiple output tensors would otherwise lead to duplicate
334    names in self.output_names.
335    """
336    uniquified = []
337    output_names = set()
338    prefix_count = {}
339    for layer in self._output_layers:
340      proposal = layer.name
341      while proposal in output_names:
342        existing_count = prefix_count.get(layer.name, 1)
343        proposal = '{}_{}'.format(layer.name, existing_count)
344        prefix_count[layer.name] = existing_count + 1
345      output_names.add(proposal)
346      uniquified.append(proposal)
347    self.output_names = uniquified
348
349  @property
350  def _layer_checkpoint_dependencies(self):
351    """Dictionary of layer dependencies to be included in the checkpoint."""
352    weight_layer_index = 0
353
354    dependencies = collections.OrderedDict()
355    for layer_index, layer in enumerate(self.layers):
356      try:
357        if layer.weights:
358          # Keep a separate index for layers which have weights. This allows
359          # users to insert Layers without weights anywhere in the network
360          # without breaking checkpoints.
361          dependencies['layer_with_weights-%d' % weight_layer_index] = layer
362          weight_layer_index += 1
363      except ValueError:
364        # The layer might have weights, but may not be built yet. We just treat
365        # it as layer without weight.
366        pass
367
368      # Even if it doesn't have weights, we should still track everything in
369      # case it has/will have Trackable dependencies.
370      dependencies['layer-%d' % layer_index] = layer
371    return dependencies
372
373  @property
374  def _checkpoint_dependencies(self):
375    dependencies = [
376        trackable.TrackableReference(name=name, ref=layer)
377        for name, layer in self._layer_checkpoint_dependencies.items()]
378    dependencies.extend(super(Functional, self)._checkpoint_dependencies)
379    return dependencies
380
381  def _lookup_dependency(self, name):
382    layer_dependencies = self._layer_checkpoint_dependencies
383    if name in layer_dependencies:
384      return layer_dependencies[name]
385    return super(Functional, self)._lookup_dependency(name)
386
387  def _handle_deferred_layer_dependencies(self, layers):
388    """Handles layer checkpoint dependencies that are added after init."""
389    layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
390    layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
391    for layer in layers:
392      if layer in layer_to_name:
393        self._handle_deferred_dependencies(name=layer_to_name[layer],
394                                           trackable=layer)
395
396  @property
397  def _should_compute_mask(self):
398    return True
399
400  def compute_mask(self, inputs, mask):
401    # TODO(omalleyt): b/123540974 This function is not really safe to call
402    # by itself because it will duplicate any updates and losses in graph
403    # mode by `call`ing the Layers again.
404    output_tensors = self._run_internal_graph(inputs, mask=mask)
405    return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
406                              output_tensors)
407
408  @doc_controls.do_not_doc_inheritable
409  def call(self, inputs, training=None, mask=None):
410    """Calls the model on new inputs.
411
412    In this case `call` just reapplies
413    all ops in the graph to the new inputs
414    (e.g. build a new computational graph from the provided inputs).
415
416    Args:
417        inputs: A tensor or list of tensors.
418        training: Boolean or boolean scalar tensor, indicating whether to run
419          the `Network` in training mode or inference mode.
420        mask: A mask or list of masks. A mask can be
421            either a tensor or None (no mask).
422
423    Returns:
424        A tensor if there is a single output, or
425        a list of tensors if there are more than one outputs.
426    """
427    return self._run_internal_graph(
428        inputs, training=training, mask=mask)
429
430  def compute_output_shape(self, input_shape):
431    # Convert any shapes in tuple format to TensorShapes.
432    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
433
434    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
435      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
436                       ': model has ' + str(len(self._input_layers)) +
437                       ' tensor inputs.')
438
439    # Use the tuple of TensorShape as the cache key, since tuple is hashable
440    # and can be used as hash key.
441    try:
442      cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
443      if cache_key in self._output_shape_cache:
444        # Cache hit. Return shapes as TensorShapes.
445        return self._output_shape_cache[cache_key]
446    except ValueError:
447      # In case there are unknown TensorShape, eg for sparse tensor input,
448      # We skip the caching since the shape is unknown.
449      pass
450
451    layers_to_output_shapes = {}
452    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
453      # It's an input layer: then `compute_output_shape` is identity,
454      # and there is only one node and one tensor..
455      shape_key = layer.name + '_0_0'
456      layers_to_output_shapes[shape_key] = shape
457
458    depth_keys = list(self._nodes_by_depth.keys())
459    depth_keys.sort(reverse=True)
460    # Iterate over nodes, by depth level.
461    if len(depth_keys) > 1:
462      for depth in depth_keys:
463        nodes = self._nodes_by_depth[depth]
464        for node in nodes:
465          layer = node.layer
466          if layer in self._input_layers:
467            # We've already covered the input layers
468            # a few lines above.
469            continue
470          # Get the input shapes for the first argument of the node
471          layer_input_shapes = []
472          layer_inputs = node.call_args[0]
473          for layer_input in nest.flatten(layer_inputs):
474            kh = layer_input._keras_history
475            input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index,
476                                                          kh.tensor_index)
477            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
478          layer_input_shapes = nest.pack_sequence_as(layer_inputs,
479                                                     layer_input_shapes)
480          # Layers expect shapes to be tuples for `compute_output_shape`.
481          layer_input_shapes = tf_utils.convert_shapes(
482              layer_input_shapes, to_tuples=True)
483          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
484          # Convert back to TensorShapes.
485          layer_output_shapes = tf_utils.convert_shapes(
486              layer_output_shapes, to_tuples=False)
487
488          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
489          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
490            shape_key = layer.name + '_%s_%s' % (node_index, j)
491            layers_to_output_shapes[shape_key] = shape
492
493      # Read final output shapes from layers_to_output_shapes.
494      output_shapes = []
495      for i in range(len(self._output_layers)):
496        layer, node_index, tensor_index = self._output_coordinates[i]
497        shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
498        output_shapes.append(layers_to_output_shapes[shape_key])
499      output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
500      # Store in cache.
501      self._output_shape_cache[cache_key] = output_shapes
502
503    # Return shapes as TensorShapes.
504    return output_shapes
505
506  def _init_set_name(self, name, zero_based=True):
507    if not name:
508      cls_name = self.__class__.__name__
509      if self.__class__ == Functional:
510        # Hide the functional class name from user, since its not a public
511        # visible class. Use "Model" instead,
512        cls_name = 'Model'
513      self._name = backend.unique_object_name(
514          generic_utils.to_snake_case(cls_name),
515          zero_based=zero_based)
516    else:
517      self._name = name
518
519  def _run_internal_graph(self, inputs, training=None, mask=None):
520    """Computes output tensors for new inputs.
521
522    # Note:
523        - Can be run on non-Keras tensors.
524
525    Args:
526        inputs: Tensor or nested structure of Tensors.
527        training: Boolean learning phase.
528        mask: (Optional) Tensor or nested structure of Tensors.
529
530    Returns:
531        output_tensors
532    """
533    inputs = self._flatten_to_reference_inputs(inputs)
534    if mask is None:
535      masks = [None] * len(inputs)
536    else:
537      masks = self._flatten_to_reference_inputs(mask)
538    for input_t, mask in zip(inputs, masks):
539      input_t._keras_mask = mask
540
541    # Dictionary mapping reference tensors to computed tensors.
542    tensor_dict = {}
543    tensor_usage_count = self._tensor_usage_count
544    for x, y in zip(self.inputs, inputs):
545      y = self._conform_to_reference_input(y, ref_input=x)
546      x_id = str(id(x))
547      tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
548
549    nodes_by_depth = self._nodes_by_depth
550    depth_keys = list(nodes_by_depth.keys())
551    depth_keys.sort(reverse=True)
552
553    for depth in depth_keys:
554      nodes = nodes_by_depth[depth]
555      for node in nodes:
556        if node.is_input:
557          continue  # Input tensors already exist.
558
559        if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
560          continue  # Node is not computable, try skipping.
561
562        args, kwargs = node.map_arguments(tensor_dict)
563        outputs = node.layer(*args, **kwargs)
564
565        # Update tensor_dict.
566        for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
567          tensor_dict[x_id] = [y] * tensor_usage_count[x_id]
568
569    output_tensors = []
570    for x in self.outputs:
571      x_id = str(id(x))
572      assert x_id in tensor_dict, 'Could not compute output ' + str(x)
573      output_tensors.append(tensor_dict[x_id].pop())
574
575    return nest.pack_sequence_as(self._nested_outputs, output_tensors)
576
577  def _flatten_to_reference_inputs(self, tensors):
578    """Maps `tensors` to their respective `keras.Input`."""
579    if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
580      ref_inputs = self._nested_inputs
581      if not nest.is_nested(ref_inputs):
582        ref_inputs = [self._nested_inputs]
583      if isinstance(ref_inputs, dict):
584        # In the case that the graph is constructed with dict input tensors,
585        # We will use the original dict key to map with the keys in the input
586        # data. Note that the model.inputs is using nest.flatten to process the
587        # input tensors, which means the dict input tensors are ordered by their
588        # keys.
589        ref_input_names = sorted(ref_inputs.keys())
590      else:
591        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]
592
593      # Raise an warning if there are more input data comparing to input tensor
594      if len(tensors) > len(ref_input_names):
595        warnings.warn(
596            'Input dict contained keys {} which did not match any model input. '
597            'They will be ignored by the model.'.format(
598                [n for n in tensors.keys() if n not in ref_input_names])
599            )
600
601      try:
602        # Flatten in the order `Input`s were passed during Model construction.
603        return [tensors[n] for n in ref_input_names]
604      except KeyError:
605        # TODO(b/151582614)
606        return nest.flatten(tensors)
607
608    # Otherwise both self.inputs and tensors will already be in same order.
609    return nest.flatten(tensors)
610
611  def _conform_to_reference_input(self, tensor, ref_input):
612    """Set shape and dtype based on `keras.Input`s."""
613    if isinstance(tensor, ops.Tensor):
614      # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use
615      # the shape specified by the `keras.Input`.
616      t_shape = tensor.shape
617      t_rank = t_shape.rank
618      ref_shape = ref_input.shape
619      ref_rank = ref_shape.rank
620      keras_history = getattr(tensor, '_keras_history', None)
621      if t_rank is not None and ref_rank is not None:
622        # Should squeeze last dimension.
623        # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
624        if (t_rank == ref_rank + 1 and t_shape[-1] == 1):
625          tensor = array_ops.squeeze_v2(tensor, axis=-1)
626        # Should expand last_dimension.
627        # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
628        elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
629          tensor = array_ops.expand_dims_v2(tensor, axis=-1)
630      if keras_history is not None:  # Restore keras history.
631        tensor._keras_history = keras_history
632
633      # Add shape hints to Tensors that may have None shape dims but have shapes
634      # defined by the `keras.Input` (not applicable in eager mode).
635      if not context.executing_eagerly():
636        try:
637          tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
638        except ValueError:
639          logging.warning(
640              'Model was constructed with shape {} for input {}, but it was '
641              'called on an input with incompatible shape {}.'.format(
642                  ref_input.shape, ref_input, tensor.shape))
643
644      # Dtype casting.
645      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
646    elif tf_utils.is_extension_type(tensor):
647      # Dtype casting (If the extension type has a non-variant dtype and
648      # supports being cast)
649      ref_input_dtype = getattr(ref_input, 'dtype', None)
650      if ref_input_dtype is not None and ref_input_dtype != dtypes.variant:
651        tensor = math_ops.cast(tensor, dtype=ref_input_dtype)
652
653    return tensor
654
655  def get_config(self):
656    return copy.deepcopy(get_network_config(self))
657
658  @classmethod
659  def from_config(cls, config, custom_objects=None):
660    """Instantiates a Model from its config (output of `get_config()`).
661
662    Args:
663        config: Model config dictionary.
664        custom_objects: Optional dictionary mapping names
665            (strings) to custom classes or functions to be
666            considered during deserialization.
667
668    Returns:
669        A model instance.
670
671    Raises:
672        ValueError: In case of improperly formatted config dict.
673    """
674    with generic_utils.SharedObjectLoadingScope():
675      input_tensors, output_tensors, created_layers = reconstruct_from_config(
676          config, custom_objects)
677      model = cls(inputs=input_tensors, outputs=output_tensors,
678                  name=config.get('name'))
679      connect_ancillary_layers(model, created_layers)
680      return model
681
682  def _validate_graph_inputs_and_outputs(self):
683    """Validates the inputs and outputs of a Graph Network."""
684    # Check for redundancy in inputs.
685    if len({id(i) for i in self.inputs}) != len(self.inputs):
686      raise ValueError('The list of inputs passed to the model '
687                       'is redundant. '
688                       'All inputs should only appear once.'
689                       ' Found: ' + str(self.inputs))
690
691    for x in self.inputs:
692      # Check that x has appropriate `_keras_history` metadata.
693      if not hasattr(x, '_keras_history'):
694        cls_name = self.__class__.__name__
695        raise ValueError('Input tensors to a ' + cls_name + ' ' +
696                         'must come from `tf.keras.Input`. '
697                         'Received: ' + str(x) +
698                         ' (missing previous layer metadata).')
699      # Check that x is an input tensor.
700      # pylint: disable=protected-access
701      layer = x._keras_history.layer
702      if len(layer._inbound_nodes) > 1 or (
703          layer._inbound_nodes and not layer._inbound_nodes[0].is_input):
704        cls_name = self.__class__.__name__
705        logging.warning(cls_name + ' model inputs must come from '
706                        '`tf.keras.Input` (thus holding past layer metadata), '
707                        'they cannot be the output of '
708                        'a previous non-Input layer. '
709                        'Here, a tensor specified as '
710                        'input to "' + self.name + '" was not an Input tensor, '
711                        'it was generated by layer ' + layer.name + '.\n'
712                        'Note that input tensors are '
713                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
714                        'The tensor that caused the issue was: ' + str(x.name))
715
716    # Check compatibility of batch sizes of Input Layers.
717    input_batch_sizes = [
718        training_utils.get_static_batch_size(x._keras_history.layer)
719        for x in self.inputs
720    ]
721    consistent_batch_size = None
722    for batch_size in input_batch_sizes:
723      if batch_size is not None:
724        if (consistent_batch_size is not None and
725            batch_size != consistent_batch_size):
726          raise ValueError('The specified batch sizes of the Input Layers'
727                           ' are incompatible. Found batch sizes: {}'.format(
728                               input_batch_sizes))
729        consistent_batch_size = batch_size
730
731    for x in self.outputs:
732      if not hasattr(x, '_keras_history'):
733        cls_name = self.__class__.__name__
734        raise ValueError('Output tensors of a ' + cls_name + ' model must be '
735                         'the output of a TensorFlow `Layer` '
736                         '(thus holding past layer metadata). Found: ' + str(x))
737
738  def _insert_layers(self, layers, relevant_nodes=None):
739    """Inserts Layers into the Network after Network creation.
740
741    This is only valid for Keras Graph Networks.  Layers added via this function
742    will be included in the `call` computation and `get_config` of this Network.
743    They will not be added to the Network's outputs.
744
745
746    Args:
747      layers: Arbitrary nested structure of Layers. Layers must be reachable
748        from one or more of the `keras.Input` Tensors that correspond to this
749        Network's inputs.
750      relevant_nodes: Nodes from the Layers that should be considered part of
751        this Network. If `None`, all Nodes will be considered part of this
752        Network.
753
754    Raises:
755      ValueError: If the layers depend on `Input`s not found in this Model.
756    """
757    layers = nest.flatten(layers)
758    tf_utils.assert_no_legacy_layers(layers)
759    node_to_depth = {}
760    for depth, nodes in self._nodes_by_depth.items():
761      node_to_depth.update({node: depth for node in nodes})
762    # The nodes of these Layers that are relevant to this Network. If not
763    # provided, assume all Nodes are relevant
764    if not relevant_nodes:
765      relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers])
766    network_nodes = set(relevant_nodes + list(node_to_depth.keys()))
767
768    def _get_min_depth(node):
769      """Gets the minimum depth at which node can be computed."""
770      min_depth = 0
771      for layer, node_id, _, _ in node.iterate_inbound():
772        inbound_node = layer._inbound_nodes[node_id]
773        if inbound_node in node_to_depth:
774          min_depth = min(min_depth, node_to_depth[inbound_node])
775        elif inbound_node not in network_nodes:
776          continue
777        else:
778          # Previous relevant nodes haven't been processed yet.
779          return None
780      # New node is one shallower than its shallowest input.
781      return min_depth - 1
782
783    # Insert nodes into `_nodes_by_depth` and other node attrs.
784    unprocessed_nodes = copy.copy(relevant_nodes)
785    i = 0
786    while unprocessed_nodes:
787      i += 1
788      # Do a sanity check. This can occur if `Input`s from outside this Model
789      # are being relied on.
790      if i > 10000:
791        raise ValueError('Layers could not be added due to missing '
792                         'dependencies.')
793
794      node = unprocessed_nodes.pop(0)
795      depth = _get_min_depth(node)
796      if depth is None:  # Defer until inbound nodes are processed.
797        unprocessed_nodes.append(node)
798        continue
799      node_key = _make_node_key(node.layer.name,
800                                node.layer._inbound_nodes.index(node))
801      if node_key not in self._network_nodes:
802        node_to_depth[node] = depth
803        self._network_nodes.add(node_key)
804        self._nodes_by_depth[depth].append(node)
805
806    # Insert layers and update other layer attrs.
807    layer_set = set(self._self_tracked_trackables)
808    deferred_layers = []
809    for layer in layers:
810      if layer not in layer_set:
811        self._self_tracked_trackables.append(layer)
812        deferred_layers.append(layer)
813        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
814        layer_set.add(layer)
815    self._handle_deferred_layer_dependencies(deferred_layers)
816
817    self._compute_tensor_usage_count()
818
819  def _compute_tensor_usage_count(self):
820    """Compute the #. of tensor usages for all the output tensors of layers.
821
822    The computed tensor usage count is saved as `self._tensor_usage_count`. This
823    is later used for saving memory in eager computation by releasing
824    no-longer-needed tensors as early as possible.
825    """
826    tensor_usage_count = collections.Counter()
827    available_tensors = set(str(id(tensor)) for tensor in self.inputs)
828
829    depth_keys = list(self._nodes_by_depth.keys())
830    depth_keys.sort(reverse=True)
831    depth_keys = depth_keys[1:]
832
833    for depth in depth_keys:
834      for node in self._nodes_by_depth[depth]:
835        input_tensors = {
836            str(id(tensor)) for tensor in nest.flatten(node.keras_inputs)
837        }
838        if input_tensors.issubset(available_tensors):
839          for tensor in nest.flatten(node.keras_inputs):
840            tensor_usage_count[str(id(tensor))] += 1
841
842          for output_tensor in nest.flatten(node.outputs):
843            available_tensors.add(str(id(output_tensor)))
844
845    for tensor in self.outputs:
846      tensor_usage_count[str(id(tensor))] += 1
847
848    self._tensor_usage_count = tensor_usage_count
849
850  def _assert_weights_created(self):
851    # Override the implementation in Model.
852    # The Functional model should always have weight created already.
853    return
854
855  def _graph_network_add_loss(self, symbolic_loss):
856    new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
857    # Losses must be keyed on inputs no matter what in order to be supported in
858    # DistributionStrategy.
859    add_loss_layer = base_layer.AddLoss(
860        unconditional=False, dtype=symbolic_loss.dtype)
861    add_loss_layer(symbolic_loss)
862    new_nodes.extend(add_loss_layer.inbound_nodes)
863    new_layers.append(add_loss_layer)
864    self._insert_layers(new_layers, new_nodes)
865
866  def _graph_network_add_metric(self, value, aggregation, name):
867    new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
868    add_metric_layer = base_layer.AddMetric(
869        aggregation, name, dtype=value.dtype)
870    add_metric_layer(value)
871    new_nodes.extend(add_metric_layer.inbound_nodes)
872    new_layers.append(add_metric_layer)
873    self._insert_layers(new_layers, new_nodes)
874
875  @property
876  def _trackable_saved_model_saver(self):
877    return network_serialization.NetworkSavedModelSaver(self)
878
879  def _get_save_spec(self, dynamic_batch=True):
880    if getattr(self, '_has_explicit_input_shape', True):
881      # Functional models and Sequential models that have an explicit input
882      # shape should use the batch size set by the input layer.
883      dynamic_batch = False
884    return super(Functional, self)._get_save_spec(dynamic_batch)
885
886
887def _make_node_key(layer_name, node_index):
888  return layer_name + '_ib-' + str(node_index)
889
890
891def _map_graph_network(inputs, outputs):
892  """Validates a network's topology and gather its layers and nodes.
893
894  Args:
895    inputs: List of input tensors.
896    outputs: List of outputs tensors.
897
898  Returns:
899    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
900    - nodes: list of Node instances.
901    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
902    - layers: list of Layer instances.
903    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.
904
905  Raises:
906    ValueError: In case the network is not valid (e.g. disconnected graph).
907  """
908  # "depth" is number of layers between output Node and the Node.
909  # Nodes are ordered from inputs -> outputs.
910  nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
911  network_nodes = {
912      _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
913      for node in nodes_in_decreasing_depth
914  }
915
916  nodes_depths = {}  # dict {node: depth value}
917  layers_depths = {}  # dict {layer: depth value}
918
919  for node in reversed(nodes_in_decreasing_depth):
920    # If the depth is not set, the node has no outbound nodes (depth 0).
921    depth = nodes_depths.setdefault(node, 0)
922
923    # Update the depth of the corresponding layer
924    previous_depth = layers_depths.get(node.layer, 0)
925    # If we've seen this layer before at a higher depth,
926    # we should use that depth instead of the node depth.
927    # This is necessary for shared layers that have inputs at different
928    # depth levels in the graph.
929    depth = max(depth, previous_depth)
930    layers_depths[node.layer] = depth
931    nodes_depths[node] = depth
932
933    # Update the depth of inbound nodes.
934    # The "depth" of a node is the max of the depths
935    # of all nodes it is connected to + 1.
936    for node_dep in node.parent_nodes:
937      previous_depth = nodes_depths.get(node_dep, 0)
938      nodes_depths[node_dep] = max(depth + 1, previous_depth)
939
940  # Handle inputs that are not connected to outputs.
941  # We do not error out here because the inputs may be used to compute losses
942  # and metrics.
943  for input_t in inputs:
944    input_layer = input_t._keras_history[0]
945    if input_layer not in layers_depths:
946      layers_depths[input_layer] = 0
947      layer_indices[input_layer] = -1
948      nodes_depths[input_layer._inbound_nodes[0]] = 0
949      network_nodes.add(_make_node_key(input_layer.name, 0))
950
951  # Build a dict {depth: list of nodes with this depth}
952  nodes_by_depth = collections.defaultdict(list)
953  for node, depth in nodes_depths.items():
954    nodes_by_depth[depth].append(node)
955
956  # Build a dict {depth: list of layers with this depth}
957  layers_by_depth = collections.defaultdict(list)
958  for layer, depth in layers_depths.items():
959    layers_by_depth[depth].append(layer)
960
961  # Get sorted list of layer depths.
962  depth_keys = list(layers_by_depth.keys())
963  depth_keys.sort(reverse=True)
964
965  # Set self.layers ordered by depth.
966  layers = []
967  for depth in depth_keys:
968    layers_for_depth = layers_by_depth[depth]
969    # Network.layers needs to have a deterministic order:
970    # here we order them by traversal order.
971    layers_for_depth.sort(key=lambda x: layer_indices[x])
972    layers.extend(layers_for_depth)
973
974  # Get sorted list of node depths.
975  depth_keys = list(nodes_by_depth.keys())
976  depth_keys.sort(reverse=True)
977
978  # Check that all tensors required are computable.
979  # computable_tensors: all tensors in the graph
980  # that can be computed from the inputs provided.
981  computable_tensors = set()
982  for x in inputs:
983    computable_tensors.add(id(x))
984
985  layers_with_complete_input = []  # To provide a better error msg.
986  for depth in depth_keys:
987    for node in nodes_by_depth[depth]:
988      layer = node.layer
989      if layer and not node.is_input:
990        for x in nest.flatten(node.keras_inputs):
991          if id(x) not in computable_tensors:
992            raise ValueError('Graph disconnected: '
993                             'cannot obtain value for tensor ' + str(x) +
994                             ' at layer "' + layer.name + '". '
995                             'The following previous layers '
996                             'were accessed without issue: ' +
997                             str(layers_with_complete_input))
998        for x in nest.flatten(node.outputs):
999          computable_tensors.add(id(x))
1000        layers_with_complete_input.append(layer.name)
1001
1002  # Ensure name unicity, which will be crucial for serialization
1003  # (since serialized nodes refer to layers by their name).
1004  all_names = [layer.name for layer in layers]
1005  for name in all_names:
1006    if all_names.count(name) != 1:
1007      raise ValueError('The name "' + name + '" is used ' +
1008                       str(all_names.count(name)) + ' times in the model. '
1009                       'All layer names should be unique.')
1010  return network_nodes, nodes_by_depth, layers, layers_by_depth
1011
1012
1013def _build_map(outputs):
1014  """This method topologically sorts nodes in order from inputs to outputs.
1015
1016  It uses a depth-first search to topologically sort nodes that appear in the
1017  _keras_history connectivity metadata of `outputs`.
1018
1019  Args:
1020    outputs: the output tensors whose _keras_history metadata should be walked.
1021    This may be an arbitrary nested structure.
1022
1023  Returns:
1024    A tuple like (ordered_nodes, layer_to_first_traversal_index)
1025    ordered_nodes: list of nodes appearing in the keras history, topologically
1026      sorted from original inputs to the `outputs`.
1027      (If outputs have different sets of ancestors, the inputs to one output
1028      may appear after a different output).
1029    layer_to_first_traversal_index:
1030      A dict mapping layer to the traversal index in the DFS where it is
1031      seen. Note: if a layer is shared by several nodes, the dict will only
1032      store the index corresponding to the *first* time the layer seen.
1033  """
1034  finished_nodes = set()
1035  nodes_in_progress = set()
1036  nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
1037  layer_indices = {}  # layer -> in traversal order.
1038  for output in nest.flatten(outputs):
1039    _build_map_helper(output, finished_nodes, nodes_in_progress,
1040                      nodes_in_decreasing_depth, layer_indices)
1041  return nodes_in_decreasing_depth, layer_indices
1042
1043
1044def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
1045                      nodes_in_decreasing_depth, layer_indices):
1046  """Recursive helper for `_build_map`."""
1047  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
1048  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access
1049
1050  # Don't repeat work for shared subgraphs
1051  if node in finished_nodes:
1052    return
1053
1054  # Prevent cycles.
1055  if node in nodes_in_progress:
1056    raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name +
1057                     '" is part of a cycle.')
1058
1059  # Store the traversal order for layer sorting.
1060  if layer not in layer_indices:
1061    layer_indices[layer] = len(layer_indices)
1062
1063  # Propagate to all previous tensors connected to this node.
1064  nodes_in_progress.add(node)
1065  if not node.is_input:
1066    for tensor in node.keras_inputs:
1067      _build_map_helper(tensor, finished_nodes, nodes_in_progress,
1068                        nodes_in_decreasing_depth, layer_indices)
1069
1070  finished_nodes.add(node)
1071  nodes_in_progress.remove(node)
1072  nodes_in_decreasing_depth.append(node)
1073
1074
1075def _map_subgraph_network(inputs, outputs):
1076  """Returns the nodes and layers in the topology from `inputs` to `outputs`.
1077
1078  Args:
1079    inputs: List of input tensors.
1080    outputs: List of output tensors.
1081
1082  Returns:
1083    A tuple of List{Node] and List[Layer].
1084  """
1085  if not keras_tensor.keras_tensors_enabled():
1086    base_layer_utils.create_keras_history(outputs)
1087  # Keep only nodes and layers in the topology between inputs and outputs.
1088  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
1089  return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers
1090
1091
1092def _should_skip_first_node(layer):
1093  """Returns True if the first layer node should not be saved or loaded."""
1094  # Networks that are constructed with an Input layer/shape start with a
1095  # pre-existing node linking their input to output. This node is excluded from
1096  # the network config.
1097  if layer._self_tracked_trackables:
1098    return (isinstance(layer, Functional) and
1099            # Filter out Sequential models without an input shape.
1100            isinstance(layer._self_tracked_trackables[0],
1101                       input_layer_module.InputLayer))
1102  else:
1103    return isinstance(layer, Functional)
1104
1105
1106def connect_ancillary_layers(model, created_layers):
1107  """Adds layers that are not connected to the outputs to the model."""
1108  # Layers not connected to outputs, such as those added in `add_loss`.
1109  ancillary_layers = [
1110      layer for layer in created_layers.values() if layer not in model.layers
1111  ]
1112  if ancillary_layers:
1113    relevant_nodes = nest.flatten([
1114        layer.inbound_nodes[1:]
1115        if _should_skip_first_node(layer) else layer.inbound_nodes
1116        for layer in created_layers.values()
1117    ])
1118    model._insert_layers(ancillary_layers, relevant_nodes)
1119  return model
1120
1121
1122def reconstruct_from_config(config, custom_objects=None, created_layers=None):
1123  """Reconstructs graph from config object.
1124
1125  Args:
1126    config: Dictionary returned from Network.get_config()
1127    custom_objects: Optional dictionary mapping names (strings) to custom
1128      classes or functions to be considered during deserialization.
1129    created_layers: Optional dictionary mapping names to Layer objects. Any
1130      layer not in this dictionary will be created and added to the dict.
1131      This function will add new nodes to all layers (excluding InputLayers),
1132      instead of re-using pre-existing nodes in the layers.
1133
1134  Returns:
1135    Tuple of (input tensors, output tensors, dictionary of created layers)
1136  """
1137  # Layer instances created during the graph reconstruction process.
1138  created_layers = created_layers or collections.OrderedDict()
1139
1140  # Maps input data (tuple of inbound layer name, node index) from the config
1141  # to node indices in the newly generated model. The node indices may be
1142  # different if the layers have already been called previously.
1143  node_index_map = {}
1144  node_count_by_layer = {}
1145
1146  # Dictionary mapping layer instances to
1147  # node data that specifies a layer call.
1148  # It acts as a queue that maintains any unprocessed
1149  # layer call until it becomes possible to process it
1150  # (i.e. until the input tensors to the call all exist).
1151  unprocessed_nodes = {}
1152
1153  def add_unprocessed_node(layer, node_data):
1154    if layer not in unprocessed_nodes:
1155      unprocessed_nodes[layer] = [node_data]
1156    else:
1157      unprocessed_nodes[layer].append(node_data)
1158
1159  def get_node_index(layer, config_node_index):
1160    """Returns node index in layer (might differ from config_node_index)."""
1161    if isinstance(layer, input_layer_module.InputLayer):
1162      return 0
1163    return node_index_map.get((layer.name, config_node_index), None)
1164
1165  def _deserialize_keras_tensors(kwargs, layer_map):
1166    """Deserializes Keras Tensors passed to `call`.."""
1167
1168    def _deserialize_keras_tensor(t):
1169      """Deserializes a single Keras Tensor passed to `call`."""
1170      if isinstance(t, tf_utils.ListWrapper):
1171        t = t.as_list()
1172        layer_name = t[0]
1173        node_index = t[1]
1174        tensor_index = t[2]
1175
1176        layer = layer_map[layer_name]
1177        new_node_index = get_node_index(layer, node_index)
1178        if new_node_index is None:
1179          # The inbound node may not have been processed yet,
1180          # (This can happen e.g. if it depends on a different set
1181          # of inputs than those that have been processed already).
1182          # raise an IndexError so that the current node puts itself
1183          # back on the unprocessed queue.
1184          # Caution: This may lead to infinite loops for malformed
1185          # network configurations! (or when there is a bug in
1186          # the network config loading code).
1187          raise IndexError
1188        node = layer._inbound_nodes[new_node_index]
1189        return nest.flatten(node.outputs)[tensor_index]
1190      return t
1191
1192    kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
1193    return nest.map_structure(_deserialize_keras_tensor, kwargs)
1194
1195  def process_node(layer, node_data):
1196    """Deserialize a node.
1197
1198    Args:
1199        layer: layer instance.
1200        node_data: Nested structure of `ListWrapper`.
1201
1202    Raises:
1203        ValueError: In case of improperly formatted `node_data`.
1204    """
1205    input_tensors = []
1206    for input_data in nest.flatten(node_data):
1207      input_data = input_data.as_list()
1208      inbound_layer_name = input_data[0]
1209      inbound_node_index = input_data[1]
1210      inbound_tensor_index = input_data[2]
1211      if len(input_data) == 3:
1212        kwargs = {}
1213      elif len(input_data) == 4:
1214        kwargs = input_data[3]
1215        try:
1216          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
1217        except IndexError:
1218          # Happens if keras tensors in kwargs are still unprocessed
1219          add_unprocessed_node(layer, node_data)
1220          return
1221      else:
1222        raise ValueError('Improperly formatted model config.')
1223
1224      if inbound_layer_name != node_module._CONSTANT_VALUE:
1225        inbound_layer = created_layers[inbound_layer_name]
1226        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
1227
1228        if inbound_node_index is None:
1229          add_unprocessed_node(layer, node_data)
1230          return
1231        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
1232        input_tensors.append(
1233            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
1234      else:
1235        # We received a constant w/ no Keras history attached
1236        input_tensors.append(inbound_tensor_index)
1237    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
1238    # Call layer on its inputs, thus creating the node
1239    # and building the layer if needed.
1240    if input_tensors is not None:
1241      if not layer._preserve_input_structure_in_config:
1242        input_tensors = (
1243            base_layer_utils.unnest_if_single_tensor(input_tensors))
1244      output_tensors = layer(input_tensors, **kwargs)
1245
1246      # Update node index map.
1247      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
1248      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
1249      node_count_by_layer[layer] += 1
1250
1251  def process_layer(layer_data):
1252    """Deserializes a layer, then call it on appropriate inputs.
1253
1254    Args:
1255        layer_data: layer config dict.
1256
1257    Raises:
1258        ValueError: In case of improperly formatted `layer_data` dict.
1259    """
1260    layer_name = layer_data['name']
1261
1262    if layer_name in created_layers:
1263      layer = created_layers[layer_name]
1264    else:
1265      # Instantiate layer.
1266      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
1267
1268      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
1269      created_layers[layer_name] = layer
1270
1271    node_count_by_layer[layer] = int(_should_skip_first_node(layer))
1272
1273    # Gather layer inputs and convert to `ListWrapper` objects.
1274    inbound_nodes_data = layer_data['inbound_nodes']
1275    inbound_nodes_data = tf_utils.convert_inner_node_data(
1276        inbound_nodes_data, wrap=True)
1277    for node_data in inbound_nodes_data:
1278      # We don't process nodes (i.e. make layer calls)
1279      # on the fly because the inbound node may not yet exist,
1280      # in case of layer shared at different topological depths
1281      # (e.g. a model such as A(B(A(B(x)))))
1282      add_unprocessed_node(layer, node_data)
1283
1284  # First, we create all layers and enqueue nodes to be processed
1285  for layer_data in config['layers']:
1286    process_layer(layer_data)
1287  # Then we process nodes in order of layer depth.
1288  # Nodes that cannot yet be processed (if the inbound node
1289  # does not yet exist) are re-enqueued, and the process
1290  # is repeated until all nodes are processed.
1291  while unprocessed_nodes:
1292    for layer_data in config['layers']:
1293      layer = created_layers[layer_data['name']]
1294      if layer in unprocessed_nodes:
1295        for node_data in unprocessed_nodes.pop(layer):
1296          process_node(layer, node_data)
1297
1298  input_tensors = []
1299  output_tensors = []
1300
1301  input_layers = tf_utils.convert_inner_node_data(
1302      config['input_layers'], wrap=True)
1303  for layer_data in nest.flatten(input_layers):
1304    layer_name, node_index, tensor_index = layer_data.as_list()
1305    assert layer_name in created_layers
1306    layer = created_layers[layer_name]
1307    node_index = get_node_index(layer, node_index)
1308    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
1309    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])
1310
1311  output_layers = tf_utils.convert_inner_node_data(
1312      config['output_layers'], wrap=True)
1313  for layer_data in nest.flatten(output_layers):
1314    layer_name, node_index, tensor_index = layer_data.as_list()
1315    assert layer_name in created_layers
1316    layer = created_layers[layer_name]
1317    node_index = get_node_index(layer, node_index)
1318    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
1319    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])
1320
1321  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
1322  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
1323  return input_tensors, output_tensors, created_layers
1324
1325
1326def get_network_config(network, serialize_layer_fn=None):
1327  """Builds the config, which consists of the node graph and serialized layers.
1328
1329  Args:
1330    network: A Network object.
1331    serialize_layer_fn: Function used to serialize layers.
1332
1333  Returns:
1334    Config dictionary.
1335  """
1336  serialize_layer_fn = (
1337      serialize_layer_fn or generic_utils.serialize_keras_object)
1338  config = {
1339      'name': network.name,
1340  }
1341  node_conversion_map = {}
1342  for layer in network.layers:
1343    kept_nodes = 1 if _should_skip_first_node(layer) else 0
1344    for original_node_index, node in enumerate(layer._inbound_nodes):
1345      node_key = _make_node_key(layer.name, original_node_index)
1346      if node_key in network._network_nodes:
1347        node_conversion_map[node_key] = kept_nodes
1348        kept_nodes += 1
1349  layer_configs = []
1350
1351  with generic_utils.SharedObjectSavingScope():
1352    for layer in network.layers:  # From the earliest layers on.
1353      filtered_inbound_nodes = []
1354      for original_node_index, node in enumerate(layer._inbound_nodes):
1355        node_key = _make_node_key(layer.name, original_node_index)
1356        if node_key in network._network_nodes and not node.is_input:
1357          # The node is relevant to the model:
1358          # add to filtered_inbound_nodes.
1359          node_data = node.serialize(_make_node_key, node_conversion_map)
1360          filtered_inbound_nodes.append(node_data)
1361
1362      layer_config = serialize_layer_fn(layer)
1363      layer_config['name'] = layer.name
1364      layer_config['inbound_nodes'] = filtered_inbound_nodes
1365      layer_configs.append(layer_config)
1366    config['layers'] = layer_configs
1367
1368  # Gather info about inputs and outputs.
1369  model_inputs = []
1370  for i in range(len(network._input_layers)):
1371    layer, node_index, tensor_index = network._input_coordinates[i]
1372    node_key = _make_node_key(layer.name, node_index)
1373    if node_key not in network._network_nodes:
1374      continue
1375    new_node_index = node_conversion_map[node_key]
1376    model_inputs.append(
1377        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
1378  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
1379  # Preserve external Keras compat for Models with single input.
1380  if not nest.is_nested(model_inputs):
1381    model_inputs = [model_inputs]
1382  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
1383  config['input_layers'] = model_inputs
1384
1385  model_outputs = []
1386  for i in range(len(network._output_layers)):
1387    layer, node_index, tensor_index = network._output_coordinates[i]
1388    node_key = _make_node_key(layer.name, node_index)
1389    if node_key not in network._network_nodes:
1390      continue
1391    new_node_index = node_conversion_map[node_key]
1392    model_outputs.append(
1393        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
1394  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
1395  # Preserve external Keras compat for Models with single output.
1396  if not nest.is_nested(model_outputs):
1397    model_outputs = [model_outputs]
1398  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
1399  config['output_layers'] = model_outputs
1400  return config
1401
1402
1403def shape_with_no_batch_size(x):
1404  if x.shape.rank is None:
1405    return None
1406  shape = x.shape.as_list()
1407  if shape:
1408    shape[0] = None
1409  return shape
1410
1411
1412class ModuleWrapper(base_layer.Layer):
1413  """Wrapper for `tf.Module`s to support the Functional and Sequential API."""
1414
1415  def __init__(self, module, method_name=None, **kwargs):
1416    """Initializes the wrapper Layer for this module.
1417
1418    Args:
1419      module: The `tf.Module` instance to be wrapped.
1420      method_name: (Optional) str. The name of the method to use as the forward
1421        pass of the module. If not set, defaults to '__call__' if defined, or
1422        'call'.
1423      **kwargs: Additional keywrod arguments. See `tf.keras.layers.Layer`.
1424
1425    Raises:
1426      ValueError: If `method` is not defined on `module`.
1427    """
1428    super(ModuleWrapper, self).__init__(**kwargs)
1429    if method_name is None:
1430      if hasattr(module, '__call__'):
1431        method_name = '__call__'
1432      elif hasattr(module, 'call'):
1433        method_name = 'call'
1434    if method_name is None or not hasattr(module, method_name):
1435      raise ValueError('{} is not defined on object {}'.format(
1436          method_name, module))
1437
1438    self._module = module
1439    self._method_name = method_name
1440
1441    # Check if module.__call__ has a `training` arg or accepts `**kwargs`.
1442    method = getattr(module, method_name)
1443    method_arg_spec = tf_inspect.getfullargspec(method)
1444    self._expects_training_arg = ('training' in method_arg_spec.args or
1445                                  method_arg_spec.varkw is not None)
1446    self._expects_mask_arg = ('mask' in method_arg_spec.args or
1447                              method_arg_spec.varkw is not None)
1448
1449  def call(self, *args, **kwargs):
1450    if 'training' in kwargs and not self._expects_training_arg:
1451      kwargs.pop('training')
1452    if 'mask' in kwargs and not self._expects_mask_arg:
1453      kwargs.pop('mask')
1454    return getattr(self._module, self._method_name)(*args, **kwargs)
1455