• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the base Layer class, from which all layers inherit."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import itertools
24import threading
25import warnings
26
27import numpy as np
28import six
29from six.moves import zip  # pylint: disable=redefined-builtin
30
31from tensorflow.python.autograph.core import ag_ctx
32from tensorflow.python.autograph.impl import api as autograph
33from tensorflow.python.distribute import distribution_strategy_context as ds_context
34from tensorflow.python.eager import context
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import func_graph
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.framework import tensor_spec
41from tensorflow.python.framework import tensor_util
42from tensorflow.python.keras import backend
43from tensorflow.python.keras import constraints
44from tensorflow.python.keras import initializers
45from tensorflow.python.keras import regularizers
46from tensorflow.python.keras.engine import base_layer
47from tensorflow.python.keras.engine import base_layer_utils
48from tensorflow.python.keras.engine import input_spec
49from tensorflow.python.keras.mixed_precision import autocast_variable
50from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
51from tensorflow.python.keras.mixed_precision import policy
52from tensorflow.python.keras.saving.saved_model import layer_serialization
53from tensorflow.python.keras.utils import generic_utils
54from tensorflow.python.keras.utils import layer_utils
55from tensorflow.python.keras.utils import object_identity
56from tensorflow.python.keras.utils import tf_inspect
57from tensorflow.python.keras.utils import tf_utils
58# A module that only depends on `keras.layers` import these from here.
59from tensorflow.python.keras.utils.generic_utils import to_snake_case  # pylint: disable=unused-import
60from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list  # pylint: disable=unused-import
61from tensorflow.python.module import module
62from tensorflow.python.ops import array_ops
63from tensorflow.python.ops import math_ops
64from tensorflow.python.ops import variables as tf_variables
65from tensorflow.python.ops.ragged import ragged_tensor
66from tensorflow.python.platform import tf_logging
67from tensorflow.python.training.tracking import base as trackable
68from tensorflow.python.training.tracking import data_structures
69from tensorflow.python.training.tracking import tracking
70from tensorflow.python.util import nest
71from tensorflow.tools.docs import doc_controls
72
73
74# pylint: disable=g-classes-have-attributes
75class Layer(base_layer.Layer):
76  """Base layer class.
77
78  This is the class from which all layers inherit.
79
80  A layer is a class implementing common neural networks operations, such
81  as convolution, batch norm, etc. These operations require managing weights,
82  losses, updates, and inter-layer connectivity.
83
84  Users will just instantiate a layer and then treat it as a callable.
85
86  We recommend that descendants of `Layer` implement the following methods:
87
88  * `__init__()`: Save configuration in member variables
89  * `build()`: Called once from `__call__`, when we know the shapes of inputs
90    and `dtype`. Should have the calls to `add_weight()`, and then
91    call the super's `build()` (which sets `self.built = True`, which is
92    nice in case the user wants to call `build()` manually before the
93    first `__call__`).
94  * `call()`: Called in `__call__` after making sure `build()` has been called
95    once. Should actually perform the logic of applying the layer to the
96    input tensors (which should be passed in as the first argument).
97
98  Args:
99    trainable: Boolean, whether the layer's variables should be trainable.
100    name: String name of the layer.
101    dtype: The dtype of the layer's computations and weights (default of
102      `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type
103      of the first input in TensorFlow 1).
104    dynamic: Set this to `True` if your layer should only be run eagerly, and
105      should not be used to generate a static computation graph.
106      This would be the case for a Tree-RNN or a recursive network,
107      for example, or generally for any layer that manipulates tensors
108      using Python control flow. If `False`, we assume that the layer can
109      safely be used to generate a static computation graph.
110
111  Attributes:
112    name: The name of the layer (string).
113    dtype: The dtype of the layer's computations and weights. If mixed
114      precision is used with a `tf.keras.mixed_precision.Policy`, this is
115      instead just the dtype of the layer's weights, as the computations are
116      done in a different dtype.
117    updates: List of update ops of this layer.
118    losses: List of losses added by this layer.
119    trainable_weights: List of variables to be included in backprop.
120    non_trainable_weights: List of variables that should not be
121      included in backprop.
122    weights: The concatenation of the lists trainable_weights and
123      non_trainable_weights (in this order).
124    trainable: Whether the layer should be trained (boolean).
125    input_spec: Optional (list of) `InputSpec` object(s) specifying the
126      constraints on inputs that can be accepted by the layer.
127
128  Each layer has a dtype, which is typically the dtype of the layer's
129  computations and variables. A layer's dtype can be queried via the
130  `Layer.dtype` property. The dtype is specified with the `dtype` constructor
131  argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()`
132  if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
133  layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed
134  precision is used, layers may have different computation and variable dtypes.
135  See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
136  """
137
138  # See tf.Module for the usage of this property.
139  # The key for _obj_reference_counts_dict is a Trackable, which could be a
140  # variable or layer etc. tf.Module._flatten will fail to flatten the key
141  # since it is trying to convert Trackable to a string. This attribute can be
142  # ignored even after the fix of nest lib, since the trackable object should
143  # already been available as individual attributes. _obj_reference_counts_dict
144  # just contains a copy of them.
145  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
146      ('_obj_reference_counts_dict',),
147      module.Module._TF_MODULE_IGNORED_PROPERTIES
148  ))
149
150  @trackable.no_automatic_dependency_tracking
151  def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
152               **kwargs):
153    self._instrument_layer_creation()
154
155    # These properties should be set by the user via keyword arguments.
156    # note that 'dtype', 'input_shape' and 'batch_input_shape'
157    # are only applicable to input layers: do not pass these keywords
158    # to non-input layers.
159    allowed_kwargs = {
160        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
161        'weights', 'activity_regularizer', 'autocast', 'implementation'
162    }
163    # Validate optional keyword arguments.
164    generic_utils.validate_kwargs(kwargs, allowed_kwargs)
165
166    # Mutable properties
167    # Indicates whether the layer's weights are updated during training
168    # and whether the layer's updates are run during training.
169    self._trainable = trainable
170    # A stateful layer is a layer whose updates are run during inference too,
171    # for instance stateful RNNs.
172    self._stateful = False
173    # Indicates whether `build` needs to be called upon layer call, to create
174    # the layer's weights.
175    self.built = False
176    self._build_input_shape = None
177    # Provides information about which inputs are compatible with the layer.
178    self._input_spec = None
179    self.supports_masking = False
180
181    self._init_set_name(name)
182    self._activity_regularizer = regularizers.get(
183        kwargs.pop('activity_regularizer', None))
184    self._maybe_create_attribute('_trainable_weights', [])
185    self._maybe_create_attribute('_non_trainable_weights', [])
186    self._updates = []
187    # Object to store all thread local layer properties.
188    self._thread_local = threading.local()
189    # A list of zero-argument lambdas which return Tensors, used for variable
190    # regularizers.
191    self._callable_losses = []
192    # A list of symbolic Tensors containing activity regularizers and losses
193    # manually added through `add_loss` in graph-building mode.
194    self._losses = []
195    # A list of metric instances corresponding to the symbolic metric tensors
196    # added using the `add_metric` API.
197    self._metrics = []
198
199    # Both graph and subclassed networks have a dtype policy. For graph
200    # networks, the policy's compute and variable dtypes are ignored. Such
201    # networks only use the policy if it is a PolicyV1, in which case it uses
202    # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
203    # subclassed networks, the compute and variable dtypes are used as like any
204    # ordinary layer.
205    self._set_dtype_policy(dtype)
206    # Boolean indicating whether the layer automatically casts its inputs to the
207    # layer's compute_dtype.
208    self._autocast = kwargs.get('autocast',
209                                base_layer_utils.v2_dtype_behavior_enabled())
210
211    # Dependencies tracked via attribute assignment.
212    # All layers in order of horizontal graph traversal.
213    # Entries are unique. For models includes input and output layers.
214    self._maybe_create_attribute('_self_tracked_trackables', [])
215
216    # These lists will be filled via successive calls
217    # to self._add_inbound_node().
218    # Used in symbolic mode only, only in conjunction with graph-networks
219    self._inbound_nodes_value = []
220    self._outbound_nodes_value = []
221
222    self._init_call_fn_args()
223
224    # Whether the `call` method can be used to build a TF graph without issues.
225    # This attribute has no effect if the model is created using the Functional
226    # API. Instead, `model.dynamic` is determined based on the internal layers.
227    self._dynamic = dynamic
228
229    # Manage input shape information if passed.
230    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
231      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
232      kwargs['input_shape'] = (kwargs['input_dim'],)
233    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
234      # In this case we will later create an input layer
235      # to insert before the current layer
236      if 'batch_input_shape' in kwargs:
237        batch_input_shape = tuple(kwargs['batch_input_shape'])
238      elif 'input_shape' in kwargs:
239        if 'batch_size' in kwargs:
240          batch_size = kwargs['batch_size']
241        else:
242          batch_size = None
243        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
244      self._batch_input_shape = batch_input_shape
245
246    # Manage initial weight values if passed.
247    self._initial_weights = kwargs.get('weights', None)
248
249    # Whether the layer will track any layers that is set as attribute on itself
250    # as sub-layers, the weights from the sub-layers will be included in the
251    # parent layer's variables() as well.
252    # Default to True, which means auto tracking is turned on. Certain subclass
253    # might want to turn it off, like Sequential model.
254    self._auto_track_sub_layers = True
255
256    # Mark this layer as having been originally built as a tf1 layer/model
257    self._originally_built_as_v1 = True
258
259    # For backwards compat reasons, most built-in layers do not guarantee
260    # That they will 100% preserve the structure of input args when saving
261    # / loading configs. E.g. they may un-nest an arg that is
262    # a list with one element.
263    self._preserve_input_structure_in_config = False
264
265  @trackable.no_automatic_dependency_tracking
266  @generic_utils.default
267  def build(self, input_shape):
268    """Creates the variables of the layer (optional, for subclass implementers).
269
270    This is a method that implementers of subclasses of `Layer` or `Model`
271    can override if they need a state-creation step in-between
272    layer instantiation and layer call.
273
274    This is typically used to create the weights of `Layer` subclasses.
275
276    Args:
277      input_shape: Instance of `TensorShape`, or list of instances of
278        `TensorShape` if the layer expects a list of inputs
279        (one instance per input).
280    """
281    if not hasattr(self.build, '_is_default'):
282      self._build_input_shape = input_shape
283    self.built = True
284
285  @doc_controls.for_subclass_implementers
286  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
287    """This is where the layer's logic lives.
288
289    Args:
290        inputs: Input tensor, or list/tuple of input tensors.
291        **kwargs: Additional keyword arguments.
292
293    Returns:
294        A tensor or list/tuple of tensors.
295    """
296    return inputs
297
298  @doc_controls.for_subclass_implementers
299  def _add_trackable(self, trackable_object, trainable):
300    """Adds a Trackable object to this layer's state.
301
302    Args:
303      trackable_object: The tf.tracking.Trackable object to add.
304      trainable: Boolean, whether the variable should be part of the layer's
305        "trainable_variables" (e.g. variables, biases) or
306        "non_trainable_variables" (e.g. BatchNorm mean and variance).
307
308    Returns:
309      The TrackableWeightHandler used to track this object.
310    """
311    handler = base_layer_utils.TrackableWeightHandler(trackable_object)
312    if trainable:
313      self._trainable_weights.append(handler)
314    else:
315      self._non_trainable_weights.append(handler)
316    return handler
317
318  @doc_controls.for_subclass_implementers
319  def add_weight(self,
320                 name=None,
321                 shape=None,
322                 dtype=None,
323                 initializer=None,
324                 regularizer=None,
325                 trainable=None,
326                 constraint=None,
327                 partitioner=None,
328                 use_resource=None,
329                 synchronization=tf_variables.VariableSynchronization.AUTO,
330                 aggregation=tf_variables.VariableAggregation.NONE,
331                 **kwargs):
332    """Adds a new variable to the layer.
333
334    Args:
335      name: Variable name.
336      shape: Variable shape. Defaults to scalar if unspecified.
337      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
338      initializer: Initializer instance (callable).
339      regularizer: Regularizer instance (callable).
340      trainable: Boolean, whether the variable should be part of the layer's
341        "trainable_variables" (e.g. variables, biases)
342        or "non_trainable_variables" (e.g. BatchNorm mean and variance).
343        Note that `trainable` cannot be `True` if `synchronization`
344        is set to `ON_READ`.
345      constraint: Constraint instance (callable).
346      partitioner: Partitioner to be passed to the `Trackable` API.
347      use_resource: Whether to use `ResourceVariable`.
348      synchronization: Indicates when a distributed a variable will be
349        aggregated. Accepted values are constants defined in the class
350        `tf.VariableSynchronization`. By default the synchronization is set to
351        `AUTO` and the current `DistributionStrategy` chooses
352        when to synchronize. If `synchronization` is set to `ON_READ`,
353        `trainable` must not be set to `True`.
354      aggregation: Indicates how a distributed variable will be aggregated.
355        Accepted values are constants defined in the class
356        `tf.VariableAggregation`.
357      **kwargs: Additional keyword arguments. Accepted values are `getter`,
358        `collections`, `experimental_autocast` and `caching_device`.
359
360    Returns:
361      The created variable. Usually either a `Variable` or `ResourceVariable`
362      instance. If `partitioner` is not `None`, a `PartitionedVariable`
363      instance is returned.
364
365    Raises:
366      RuntimeError: If called with partitioned variable regularization and
367        eager execution is enabled.
368      ValueError: When giving unsupported dtype and no initializer or when
369        trainable has been set to True with synchronization set as `ON_READ`.
370    """
371    if shape is None:
372      shape = ()
373    # Validate optional keyword arguments.
374    for kwarg in kwargs:
375      if kwarg not in ['getter', 'collections', 'experimental_autocast',
376                       'caching_device']:
377        raise TypeError('Unknown keyword argument:', kwarg)
378    getter = kwargs.pop('getter', base_layer_utils.make_variable)
379    collections_arg = kwargs.pop('collections', None)
380    # 'experimental_autocast' can be set to False by the caller to indicate an
381    # AutoCastVariable should never be created.
382    autocast = kwargs.pop('experimental_autocast', True)
383    # See the docstring for tf.Variable about the details for caching_device.
384    caching_device = kwargs.pop('caching_device', None)
385
386    if dtype is None:
387      dtype = self.dtype or backend.floatx()
388    dtype = dtypes.as_dtype(dtype)
389    if self._dtype_policy.variable_dtype is None:
390      # The policy is "_infer", so we infer the policy from the variable dtype.
391      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
392    initializer = initializers.get(initializer)
393    regularizer = regularizers.get(regularizer)
394    constraint = constraints.get(constraint)
395
396    if synchronization == tf_variables.VariableSynchronization.ON_READ:
397      if trainable:
398        raise ValueError(
399            'Synchronization value can be set to '
400            'VariableSynchronization.ON_READ only for non-trainable variables. '
401            'You have specified trainable=True and '
402            'synchronization=VariableSynchronization.ON_READ.')
403      else:
404        # Set trainable to be false when variable is to be synced on read.
405        trainable = False
406    elif trainable is None:
407      trainable = True
408
409    # Initialize variable when no initializer provided
410    if initializer is None:
411      # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
412      if dtype.is_floating:
413        initializer = initializers.get('glorot_uniform')
414      # If dtype is DT_INT/DT_UINT, provide a default value `zero`
415      # If dtype is DT_BOOL, provide a default value `FALSE`
416      elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
417        initializer = initializers.zeros()
418      # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
419      else:
420        raise ValueError('An initializer for variable %s of type %s is required'
421                         ' for layer %s' % (name, dtype.base_dtype, self.name))
422
423    if (autocast and
424        self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
425        and dtype.is_floating):
426      # Wrap 'getter' with a version that returns an AutoCastVariable.
427      old_getter = getter
428      def getter(*args, **kwargs):  # pylint: disable=function-redefined
429        variable = old_getter(*args, **kwargs)
430        return autocast_variable.create_autocast_variable(variable)
431      # Also the caching_device does not work with the mixed precision API,
432      # disable it if it is specified.
433      # TODO(b/142020079): Reenable it once the bug is fixed.
434      if caching_device is not None:
435        tf_logging.warn('`caching_device` does not work with mixed precision '
436                        'API. Ignoring user specified `caching_device`.')
437        caching_device = None
438
439    variable = self._add_variable_with_custom_getter(
440        name=name,
441        shape=shape,
442        # TODO(allenl): a `make_variable` equivalent should be added as a
443        # `Trackable` method.
444        getter=getter,
445        # Manage errors in Layer rather than Trackable.
446        overwrite=True,
447        initializer=initializer,
448        dtype=dtype,
449        constraint=constraint,
450        trainable=trainable,
451        partitioner=partitioner,
452        use_resource=use_resource,
453        collections=collections_arg,
454        synchronization=synchronization,
455        aggregation=aggregation,
456        caching_device=caching_device)
457    if regularizer is not None:
458      # TODO(fchollet): in the future, this should be handled at the
459      # level of variable creation, and weight regularization losses
460      # should be variable attributes.
461      name_in_scope = variable.name[:variable.name.find(':')]
462      self._handle_weight_regularization(name_in_scope,
463                                         variable,
464                                         regularizer)
465    if base_layer_utils.is_split_variable(variable):
466      for v in variable:
467        backend.track_variable(v)
468        if trainable:
469          self._trainable_weights.append(v)
470        else:
471          self._non_trainable_weights.append(v)
472    else:
473      backend.track_variable(variable)
474      if trainable:
475        self._trainable_weights.append(variable)
476      else:
477        self._non_trainable_weights.append(variable)
478    return variable
479
480  @generic_utils.default
481  def get_config(self):
482    """Returns the config of the layer.
483
484    A layer config is a Python dictionary (serializable)
485    containing the configuration of a layer.
486    The same layer can be reinstantiated later
487    (without its trained weights) from this configuration.
488
489    The config of a layer does not include connectivity
490    information, nor the layer class name. These are handled
491    by `Network` (one layer of abstraction above).
492
493    Returns:
494        Python dictionary.
495    """
496    all_args = tf_inspect.getfullargspec(self.__init__).args
497    config = {'name': self.name, 'trainable': self.trainable}
498    if hasattr(self, '_batch_input_shape'):
499      config['batch_input_shape'] = self._batch_input_shape
500    config['dtype'] = policy.serialize(self._dtype_policy)
501    if hasattr(self, 'dynamic'):
502      # Only include `dynamic` in the `config` if it is `True`
503      if self.dynamic:
504        config['dynamic'] = self.dynamic
505      elif 'dynamic' in all_args:
506        all_args.remove('dynamic')
507    expected_args = config.keys()
508    # Finds all arguments in the `__init__` that are not in the config:
509    extra_args = [arg for arg in all_args if arg not in expected_args]
510    # Check that either the only argument in the `__init__` is  `self`,
511    # or that `get_config` has been overridden:
512    if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
513      raise NotImplementedError('Layers with arguments in `__init__` must '
514                                'override `get_config`.')
515    return config
516
517  @classmethod
518  def from_config(cls, config):
519    """Creates a layer from its config.
520
521    This method is the reverse of `get_config`,
522    capable of instantiating the same layer from the config
523    dictionary. It does not handle layer connectivity
524    (handled by Network), nor weights (handled by `set_weights`).
525
526    Args:
527        config: A Python dictionary, typically the
528            output of get_config.
529
530    Returns:
531        A layer instance.
532    """
533    return cls(**config)
534
535  def compute_output_shape(self, input_shape):
536    """Computes the output shape of the layer.
537
538    If the layer has not been built, this method will call `build` on the
539    layer. This assumes that the layer will later be used with inputs that
540    match the input shape provided here.
541
542    Args:
543        input_shape: Shape tuple (tuple of integers)
544            or list of shape tuples (one per output tensor of the layer).
545            Shape tuples can include None for free dimensions,
546            instead of an integer.
547
548    Returns:
549        An input shape tuple.
550    """
551    if context.executing_eagerly():
552      # In this case we build the model first in order to do shape inference.
553      # This is acceptable because the framework only calls
554      # `compute_output_shape` on shape values that the layer would later be
555      # built for. It would however cause issues in case a user attempts to
556      # use `compute_output_shape` manually with shapes that are incompatible
557      # with the shape the Layer will be called on (these users will have to
558      # implement `compute_output_shape` themselves).
559      self._maybe_build(input_shape)
560      with ops.get_default_graph().as_default():
561        graph = func_graph.FuncGraph('graph')
562        with graph.as_default():
563          input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
564          inputs = nest.map_structure(
565              base_layer_utils.generate_placeholders_from_shape, input_shape)
566          try:
567            outputs = self(inputs, training=False)
568          except TypeError as e:
569            six.raise_from(
570                NotImplementedError(
571                    'We could not automatically infer the static shape of the '
572                    'layer\'s output. Please implement the '
573                    '`compute_output_shape` method on your layer (%s).' %
574                    self.__class__.__name__), e)
575      return nest.map_structure(lambda t: t.shape, outputs)
576    raise NotImplementedError
577
578  @doc_controls.for_subclass_implementers
579  def compute_output_signature(self, input_signature):
580    """Compute the output tensor signature of the layer based on the inputs.
581
582    Unlike a TensorShape object, a TensorSpec object contains both shape
583    and dtype information for a tensor. This method allows layers to provide
584    output dtype information if it is different from the input dtype.
585    For any layer that doesn't implement this function,
586    the framework will fall back to use `compute_output_shape`, and will
587    assume that the output dtype matches the input dtype.
588
589    Args:
590      input_signature: Single TensorSpec or nested structure of TensorSpec
591        objects, describing a candidate input for the layer.
592
593    Returns:
594      Single TensorSpec or nested structure of TensorSpec objects, describing
595        how the layer would transform the provided input.
596
597    Raises:
598      TypeError: If input_signature contains a non-TensorSpec object.
599    """
600    def check_type_return_shape(s):
601      if not isinstance(s, tensor_spec.TensorSpec):
602        raise TypeError('Only TensorSpec signature types are supported, '
603                        'but saw signature entry: {}.'.format(s))
604      return s.shape
605    input_shape = nest.map_structure(check_type_return_shape, input_signature)
606    output_shape = self.compute_output_shape(input_shape)
607    dtype = self._compute_dtype
608    if dtype is None:
609      input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
610      # Default behavior when self.dtype is None, is to use the first input's
611      # dtype.
612      dtype = input_dtypes[0]
613    return nest.map_structure(
614        lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
615        output_shape)
616
617  @generic_utils.default
618  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
619    """Computes an output mask tensor.
620
621    Args:
622        inputs: Tensor or list of tensors.
623        mask: Tensor or list of tensors.
624
625    Returns:
626        None or a tensor (or list of tensors,
627            one per output tensor of the layer).
628    """
629    if not self.supports_masking:
630      if any(m is not None for m in nest.flatten(mask)):
631        raise TypeError('Layer ' + self.name + ' does not support masking, '
632                        'but was passed an input_mask: ' + str(mask))
633      # masking not explicitly supported: return None as mask.
634      return None
635    # if masking is explicitly supported, by default
636    # carry over the input mask
637    return mask
638
639  def __call__(self, *args, **kwargs):
640    """Wraps `call`, applying pre- and post-processing steps.
641
642    Args:
643      *args: Positional arguments to be passed to `self.call`.
644      **kwargs: Keyword arguments to be passed to `self.call`.
645
646    Returns:
647      Output tensor(s).
648
649    Note:
650      - The following optional keyword arguments are reserved for specific uses:
651        * `training`: Boolean scalar tensor of Python boolean indicating
652          whether the `call` is meant for training or inference.
653        * `mask`: Boolean input mask.
654      - If the layer's `call` method takes a `mask` argument (as some Keras
655        layers do), its default value will be set to the mask generated
656        for `inputs` by the previous layer (if `input` did come from
657        a layer that generated a corresponding mask, i.e. if it came from
658        a Keras layer with masking support.
659
660    Raises:
661      ValueError: if the layer's `call` method returns None (an invalid value).
662      RuntimeError: if `super().__init__()` was not called in the constructor.
663    """
664    self._assert_built_as_v1()
665
666    if not hasattr(self, '_thread_local'):
667      raise RuntimeError(
668          'You must call `super().__init__()` in the layer constructor.')
669
670    # Grab the first positional or keyword argument.
671    if args:
672      inputs = args[0]
673      args = args[1:]
674    elif self._call_fn_args[0] in kwargs:
675      inputs = kwargs.pop(self._call_fn_args[0])
676    else:
677      raise ValueError(
678          'The first argument to `Layer.call` must always be passed.')
679
680    call_context = base_layer_utils.call_context()
681    input_list = nest.flatten(inputs)
682
683    # We will attempt to build a TF graph if & only if all inputs are symbolic.
684    # This is always the case in graph mode. It can also be the case in eager
685    # mode when all inputs can be traced back to `keras.Input()` (when building
686    # models using the functional API).
687    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
688
689    # Accept NumPy and scalar inputs by converting to Tensors.
690    if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
691      def _convert_non_tensor(x):
692        # Don't call `ops.convert_to_tensor` on all `inputs` because
693        # `SparseTensors` can't be converted to `Tensor`.
694        if isinstance(x, (np.ndarray, float, int)):
695          return ops.convert_to_tensor_v2_with_dispatch(x)
696        return x
697      inputs = nest.map_structure(_convert_non_tensor, inputs)
698      input_list = nest.flatten(inputs)
699
700    # Handle `mask` propagation from previous layer to current layer. Masks can
701    # be propagated explicitly via the `mask` argument, or implicitly via
702    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed
703    # explicitly take priority.
704    mask_arg_passed_by_framework = False
705    input_masks = self._collect_input_masks(inputs, args, kwargs)
706    if (self._expects_mask_arg and input_masks is not None and
707        not self._call_arg_was_passed('mask', args, kwargs)):
708      mask_arg_passed_by_framework = True
709      kwargs['mask'] = input_masks
710
711    # If `training` argument is None or not explicitly passed,
712    # propagate `training` value from this layer's calling layer.
713    training_value = None
714    training_arg_passed_by_framework = False
715    # Priority 1: `training` was explicitly passed.
716    if self._call_arg_was_passed('training', args, kwargs):
717      training_value = self._get_call_arg_value('training', args, kwargs)
718      if not self._expects_training_arg:
719        kwargs.pop('training')
720
721    if training_value is None:
722      # Priority 2: `training` was passed to a parent layer.
723      if call_context.training is not None:
724        training_value = call_context.training
725      # Priority 3a: `learning_phase()` has been set.
726      elif backend.global_learning_phase_is_set():
727        training_value = backend.learning_phase()
728      # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
729      elif build_graph:
730        with backend.get_graph().as_default():
731          if base_layer_utils.is_in_keras_graph():
732            training_value = backend.learning_phase()
733
734      if self._expects_training_arg and training_value is not None:
735        # Force the training_value to be bool type which matches to the contract
736        # for layer/model call args.
737        if tensor_util.is_tf_type(training_value):
738          training_value = math_ops.cast(training_value, dtypes.bool)
739        else:
740          training_value = bool(training_value)
741        args, kwargs = self._set_call_arg_value(
742            'training', training_value, args, kwargs)
743        training_arg_passed_by_framework = True
744
745    # Only create Keras history if at least one tensor originates from a
746    # `keras.Input`. Otherwise this Layer may be being used outside the Keras
747    # framework.
748    if build_graph and base_layer_utils.needs_keras_history(inputs):
749      base_layer_utils.create_keras_history(inputs)
750
751    with call_context.enter(self, inputs, build_graph, training_value):
752      # Check input assumptions set after layer building, e.g. input shape.
753      if build_graph:
754        # Symbolic execution on symbolic tensors. We will attempt to build
755        # the corresponding TF subgraph inside `backend.get_graph()`
756        # TODO(reedwm): We should assert input compatibility after the inputs
757        # are casted, not before.
758        input_spec.assert_input_compatibility(self.input_spec, inputs,
759                                              self.name)
760        graph = backend.get_graph()
761        with graph.as_default(), backend.name_scope(self._name_scope()):
762          # Build layer if applicable (if the `build` method has been
763          # overridden).
764          self._maybe_build(inputs)
765          cast_inputs = self._maybe_cast_inputs(inputs)
766
767          # Wrapping `call` function in autograph to allow for dynamic control
768          # flow and control dependencies in call. We are limiting this to
769          # subclassed layers as autograph is strictly needed only for
770          # subclassed layers and models.
771          # tf_convert will respect the value of autograph setting in the
772          # enclosing tf.function, if any.
773          if (base_layer_utils.is_subclassed(self) and
774              not base_layer_utils.from_saved_model(self)):
775            call_fn = autograph.tf_convert(
776                self.call, ag_ctx.control_status_ctx())
777          else:
778            call_fn = self.call
779
780          if not self.dynamic:
781            try:
782              with autocast_variable.enable_auto_cast_variables(
783                  self._compute_dtype_object):
784                outputs = call_fn(cast_inputs, *args, **kwargs)
785
786            except errors.OperatorNotAllowedInGraphError as e:
787              raise TypeError('You are attempting to use Python control '
788                              'flow in a layer that was not declared to be '
789                              'dynamic. Pass `dynamic=True` to the class '
790                              'constructor.\nEncountered error:\n"""\n' +
791                              str(e) + '\n"""')
792          else:
793            # We will use static shape inference to return symbolic tensors
794            # matching the specifications of the layer outputs.
795            # Since `self.dynamic` is True, we will never attempt to
796            # run the underlying TF graph (which is disconnected).
797            # TODO(fchollet): consider py_func as an alternative, which
798            # would enable us to run the underlying graph if needed.
799            outputs = self._symbolic_call(inputs)
800
801          if outputs is None:
802            raise ValueError('A layer\'s `call` method should return a '
803                             'Tensor or a list of Tensors, not None '
804                             '(layer: ' + self.name + ').')
805          if base_layer_utils.have_all_keras_metadata(inputs):
806            if training_arg_passed_by_framework:
807              args, kwargs = self._set_call_arg_value(
808                  'training', None, args, kwargs, pop_kwarg_if_none=True)
809            if mask_arg_passed_by_framework:
810              kwargs.pop('mask')
811            outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
812                                                      outputs)
813          self._handle_activity_regularization(inputs, outputs)
814          self._set_mask_metadata(inputs, outputs, input_masks)
815          if hasattr(self, '_set_inputs') and not self.inputs:
816            # Subclassed network: explicitly set metadata normally set by
817            # a call to self._set_inputs().
818            # TODO(b/120997007): This should be done in Eager as well, but
819            # causes garbage collection issues because of the placeholders
820            # created on the default Keras graph.
821            self._set_inputs(inputs, outputs)
822      else:
823        # Eager execution on data tensors.
824        with backend.name_scope(self._name_scope()):
825          self._maybe_build(inputs)
826          cast_inputs = self._maybe_cast_inputs(inputs)
827          with autocast_variable.enable_auto_cast_variables(
828              self._compute_dtype_object):
829            outputs = self.call(cast_inputs, *args, **kwargs)
830          self._handle_activity_regularization(inputs, outputs)
831          self._set_mask_metadata(inputs, outputs, input_masks)
832
833    return outputs
834
835  def _assert_built_as_v1(self):
836    if not hasattr(self, '_originally_built_as_v1'):
837      raise ValueError(
838          'Your Layer or Model is in an invalid state. '
839          'This can happen for the following cases:\n '
840          '1. You might be interleaving estimator/non-estimator models or '
841          'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
842          'with models/layers created outside of it. '
843          'Converting a model to an estimator (via model_to_estimator) '
844          'invalidates all models/layers made before the conversion (even '
845          'if they were not the model converted to an estimator). '
846          'Similarly, making a layer or a model inside a '
847          'a tf.compat.v1.Graph invalidates all layers/models you previously '
848          'made outside of the graph.\n'
849          '2. You might be using a custom keras layer implementation with '
850          ' custom __init__ which didn\'t call super().__init__. '
851          ' Please check the implementation of %s and its bases.' %
852          (type(self),))
853
854  @property
855  def dtype(self):
856    return self._dtype_policy.variable_dtype
857
858  @property
859  def name(self):
860    return self._name
861
862  @property
863  def dynamic(self):
864    return any(layer._dynamic for layer in self._flatten_layers())
865
866  @property
867  @doc_controls.do_not_generate_docs
868  def stateful(self):
869    return any(layer._stateful for layer in self._flatten_layers())
870
871  @stateful.setter
872  def stateful(self, value):
873    self._stateful = value
874
875  @property
876  def trainable(self):
877    return self._trainable
878
879  @trainable.setter
880  def trainable(self, value):
881    self._trainable = value
882    for layer in getattr(self, '_self_tracked_trackables', []):
883      layer.trainable = value
884
885  @property
886  def activity_regularizer(self):
887    """Optional regularizer function for the output of this layer."""
888    return self._activity_regularizer
889
890  @activity_regularizer.setter
891  def activity_regularizer(self, regularizer):
892    """Optional regularizer function for the output of this layer."""
893    self._activity_regularizer = regularizer
894
895  @property
896  def input_spec(self):
897    return self._input_spec
898
899  @input_spec.setter
900  # Must be decorated to prevent tracking, since the input_spec can be nested
901  # InputSpec objects.
902  @trackable.no_automatic_dependency_tracking
903  def input_spec(self, value):
904    for v in nest.flatten(value):
905      if v is not None and not isinstance(v, base_layer.InputSpec):
906        raise TypeError('Layer input_spec must be an instance of InputSpec. '
907                        'Got: {}'.format(v))
908    self._input_spec = value
909
910  @property
911  def updates(self):
912    collected_updates = []
913    all_layers = self._flatten_layers()
914    with backend.get_graph().as_default():
915      for layer in all_layers:
916        if not layer.trainable and not layer.stateful:
917          continue
918        for u in layer._updates:
919          if callable(u):
920            try:
921              u = u()
922            except ValueError as e:
923              if 'InaccessibleTensorError' in type(e).__name__:
924                # For one specific case of error we try to raise
925                # a more meaningful error message about the graph if we can.
926                # This error is an internal TF symbol that is not
927                # publicly exposed, so we check the name directly rather
928                # than using a direct import.
929                base_layer_utils.check_graph_consistency(
930                    method='add_update', force_raise=True)
931              raise  # check_graph_consistency may not always raise.
932          base_layer_utils.check_graph_consistency(u, method='add_update')
933          collected_updates.append(u)
934    return collected_updates
935
936  @property
937  def losses(self):
938    """Losses which are associated with this `Layer`.
939
940    Variable regularization tensors are created when this property is accessed,
941    so it is eager safe: accessing `losses` under a `tf.GradientTape` will
942    propagate gradients back to the corresponding variables.
943
944    Returns:
945      A list of tensors.
946    """
947    collected_losses = []
948    all_layers = self._flatten_layers()
949    for layer in all_layers:
950      # If any eager losses are present, we assume the model to be part of an
951      # eager training loop (either a custom one or the one used when
952      # `run_eagerly=True`) and so we always return just the eager losses.
953      collected_losses.extend(layer._losses)
954      for regularizer in layer._callable_losses:
955        loss_tensor = regularizer()
956        if loss_tensor is not None:
957          collected_losses.append(loss_tensor)
958    return collected_losses
959
960  @doc_controls.for_subclass_implementers
961  def add_loss(self, losses, inputs=None):
962    """Add loss tensor(s), potentially dependent on layer inputs.
963
964    Some losses (for instance, activity regularization losses) may be dependent
965    on the inputs passed when calling a layer. Hence, when reusing the same
966    layer on different inputs `a` and `b`, some entries in `layer.losses` may
967    be dependent on `a` and some on `b`. This method automatically keeps track
968    of dependencies.
969
970    This method can be used inside a subclassed layer or model's `call`
971    function, in which case `losses` should be a Tensor or list of Tensors.
972
973    Example:
974
975    ```python
976    class MyLayer(tf.keras.layers.Layer):
977      def call(inputs, self):
978        self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True)
979        return inputs
980    ```
981
982    This method can also be called directly on a Functional Model during
983    construction. In this case, any loss Tensors passed to this Model must
984    be symbolic and be able to be traced back to the model's `Input`s. These
985    losses become part of the model's topology and are tracked in `get_config`.
986
987    Example:
988
989    ```python
990    inputs = tf.keras.Input(shape=(10,))
991    x = tf.keras.layers.Dense(10)(inputs)
992    outputs = tf.keras.layers.Dense(1)(x)
993    model = tf.keras.Model(inputs, outputs)
994    # Activity regularization.
995    model.add_loss(tf.abs(tf.reduce_mean(x)))
996    ```
997
998    If this is not the case for your loss (if, for example, your loss references
999    a `Variable` of one of the model's layers), you can wrap your loss in a
1000    zero-argument lambda. These losses are not tracked as part of the model's
1001    topology since they can't be serialized.
1002
1003    Example:
1004
1005    ```python
1006    inputs = tf.keras.Input(shape=(10,))
1007    x = tf.keras.layers.Dense(10)(inputs)
1008    outputs = tf.keras.layers.Dense(1)(x)
1009    model = tf.keras.Model(inputs, outputs)
1010    # Weight regularization.
1011    model.add_loss(lambda: tf.reduce_mean(x.kernel))
1012    ```
1013
1014    The `get_losses_for` method allows to retrieve the losses relevant to a
1015    specific set of inputs.
1016
1017    Args:
1018      losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
1019        may also be zero-argument callables which create a loss tensor.
1020      inputs: Ignored when executing eagerly. If anything other than None is
1021        passed, it signals the losses are conditional on some of the layer's
1022        inputs, and thus they should only be run where these inputs are
1023        available. This is the case for activity regularization losses, for
1024        instance. If `None` is passed, the losses are assumed
1025        to be unconditional, and will apply across all dataflows of the layer
1026        (e.g. weight regularization losses).
1027    """
1028    def _tag_unconditional(loss):
1029      """Process the loss and tag it by setting loss._unconditional_loss."""
1030      if callable(loss):
1031        # We run the loss without autocasting, as regularizers are often
1032        # numerically unstable in float16.
1033        with autocast_variable.enable_auto_cast_variables(None):
1034          loss = loss()
1035      if loss is None:
1036        return None  # Will be filtered out when computing the .losses property
1037      if not tensor_util.is_tf_type(loss):
1038        loss = ops.convert_to_tensor_v2_with_dispatch(
1039            loss, dtype=backend.floatx())
1040      loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
1041      return loss
1042
1043    losses = nest.flatten(losses)
1044
1045    callable_losses = []
1046    symbolic_losses = []
1047    for loss in losses:
1048      if callable(loss):
1049        callable_losses.append(functools.partial(_tag_unconditional, loss))
1050        continue
1051      if loss is None:
1052        continue
1053      if not tensor_util.is_tf_type(loss):
1054        loss = ops.convert_to_tensor_v2_with_dispatch(
1055            loss, dtype=backend.floatx())
1056      # TF Functions should take the eager path.
1057      if (tf_utils.is_symbolic_tensor(loss) and
1058          not base_layer_utils.is_in_tf_function()):
1059        symbolic_losses.append(_tag_unconditional(loss))
1060        base_layer_utils.check_graph_consistency(loss, method='add_loss')
1061
1062    self._callable_losses.extend(callable_losses)
1063
1064    in_call_context = base_layer_utils.call_context().in_call
1065
1066    if in_call_context:
1067      for symbolic_loss in symbolic_losses:
1068        self._losses.append(symbolic_loss)
1069    else:
1070      for symbolic_loss in symbolic_losses:
1071        if getattr(self, '_is_graph_network', False):
1072          self._graph_network_add_loss(symbolic_loss)
1073        else:
1074          # Possible a loss was added in a Layer's `build`.
1075          self._losses.append(symbolic_loss)
1076
1077  @property
1078  def metrics(self):
1079    collected_metrics = []
1080    for layer in self._flatten_layers():
1081      collected_metrics.extend(layer._metrics)
1082    return collected_metrics
1083
1084  @doc_controls.for_subclass_implementers
1085  def add_metric(self, value, aggregation=None, name=None):
1086    """Adds metric tensor to the layer.
1087
1088    Args:
1089      value: Metric tensor.
1090      aggregation: Sample-wise metric reduction function. If `aggregation=None`,
1091        it indicates that the metric tensor provided has been aggregated
1092        already. eg, `bin_acc = BinaryAccuracy(name='acc')` followed by
1093        `model.add_metric(bin_acc(y_true, y_pred))`. If aggregation='mean', the
1094        given metric tensor will be sample-wise reduced using `mean` function.
1095        eg, `model.add_metric(tf.reduce_sum(outputs), name='output_mean',
1096        aggregation='mean')`.
1097      name: String metric name.
1098
1099    Raises:
1100      ValueError: If `aggregation` is anything other than None or `mean`.
1101    """
1102    if aggregation is not None and aggregation != 'mean':
1103      raise ValueError(
1104          'We currently support only `mean` sample-wise metric aggregation. '
1105          'You provided aggregation=`%s`' % aggregation)
1106
1107    from_metric_obj = hasattr(value, '_metric_obj')
1108    is_symbolic = tf_utils.is_symbolic_tensor(value)
1109    in_call_context = base_layer_utils.call_context().in_call
1110
1111    if name is None and not from_metric_obj:
1112      # Eg. `self.add_metric(math_ops.reduce_sum(x), aggregation='mean')`
1113      # In eager mode, we use metric name to lookup a metric. Without a name,
1114      # a new Mean metric wrapper will be created on every model/layer call.
1115      # So, we raise an error when no name is provided.
1116      # We will do the same for symbolic mode for consistency although a name
1117      # will be generated if no name is provided.
1118
1119      # We will not raise this error in the foll use case for the sake of
1120      # consistency as name in provided in the metric constructor.
1121      # mean = metrics.Mean(name='my_metric')
1122      # model.add_metric(mean(outputs))
1123      raise ValueError('Please provide a name for your metric like '
1124                       '`self.add_metric(tf.reduce_sum(inputs), '
1125                       'name=\'mean_activation\', aggregation=\'mean\')`')
1126    elif from_metric_obj:
1127      name = value._metric_obj.name
1128
1129    if in_call_context:
1130      # TF Function path should take the eager path.
1131      self._symbolic_add_metric(value, aggregation, name)
1132    else:
1133      if not is_symbolic:
1134        raise ValueError('Expected a symbolic Tensor for the metric value, '
1135                         'received: ' + str(value))
1136
1137      # Possible a metric was added in a Layer's `build`.
1138      if not getattr(self, '_is_graph_network', False):
1139        with backend.get_graph().as_default():
1140          self._symbolic_add_metric(value, aggregation, name)
1141        return
1142
1143      if from_metric_obj:
1144        raise ValueError('Using the result of calling a `Metric` object '
1145                         'when calling `add_metric` on a Functional '
1146                         'Model is not supported. Please pass the '
1147                         'Tensor to monitor directly.')
1148
1149      # Insert layers into the Keras Graph Network.
1150      self._graph_network_add_metric(value, aggregation, name)
1151
1152  @doc_controls.for_subclass_implementers
1153  def add_update(self, updates, inputs=None):
1154    """Add update op(s), potentially dependent on layer inputs.
1155
1156    Weight updates (for instance, the updates of the moving mean and variance
1157    in a BatchNormalization layer) may be dependent on the inputs passed
1158    when calling a layer. Hence, when reusing the same layer on
1159    different inputs `a` and `b`, some entries in `layer.updates` may be
1160    dependent on `a` and some on `b`. This method automatically keeps track
1161    of dependencies.
1162
1163    The `get_updates_for` method allows to retrieve the updates relevant to a
1164    specific set of inputs.
1165
1166    This call is ignored when eager execution is enabled (in that case, variable
1167    updates are run on the fly and thus do not need to be tracked for later
1168    execution).
1169
1170    Args:
1171      updates: Update op, or list/tuple of update ops, or zero-arg callable
1172        that returns an update op. A zero-arg callable should be passed in
1173        order to disable running the updates by setting `trainable=False`
1174        on this Layer, when executing in Eager mode.
1175      inputs: Deprecated, will be automatically inferred.
1176    """
1177    if inputs is not None:
1178      tf_logging.warning(
1179          '`add_update` `inputs` kwarg has been deprecated. You no longer need '
1180          'to pass a value to `inputs` as it is being automatically inferred.')
1181    call_context = base_layer_utils.call_context()
1182
1183    if (ds_context.has_strategy() and
1184        ds_context.in_cross_replica_context() and
1185        # When saving the model, the distribution strategy context should be
1186        # ignored, following the default path for adding updates.
1187        not call_context.saving):
1188      # Updates don't need to be run in a cross-replica context.
1189      return
1190
1191    updates = generic_utils.to_list(updates)
1192
1193    if call_context.in_call:
1194      relevant_inputs = call_context.inputs
1195    else:
1196      inbound_nodes = getattr(self, '_inbound_nodes', [])
1197      relevant_inputs = [node.input_tensors for node in inbound_nodes]
1198
1199    def process_update(x):
1200      """Standardize update ops.
1201
1202      Args:
1203        x: Tensor, op, or callable.
1204
1205      Returns:
1206        An update op.
1207      """
1208      if callable(x):
1209        update = lambda: process_update(x())
1210        return update()
1211      elif isinstance(x, ops.Operation):
1212        update = x
1213      elif hasattr(x, 'op'):
1214        update = x.op
1215      else:
1216        update = ops.convert_to_tensor_v2_with_dispatch(x)
1217
1218      reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update])
1219      update._unconditional_update = update not in reachable
1220      return update
1221
1222    updates = [process_update(x) for x in updates]
1223    self._updates.extend(updates)
1224
1225  def set_weights(self, weights):
1226    """Sets the weights of the layer, from Numpy arrays.
1227
1228    The weights of a layer represent the state of the layer. This function
1229    sets the weight values from numpy arrays. The weight values should be
1230    passed in the order they are created by the layer. Note that the layer's
1231    weights must be instantiated before calling this function by calling
1232    the layer.
1233
1234    For example, a Dense layer returns a list of two values-- per-output
1235    weights and the bias value. These can be used to set the weights of another
1236    Dense layer:
1237
1238    >>> a = tf.keras.layers.Dense(1,
1239    ...   kernel_initializer=tf.constant_initializer(1.))
1240    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
1241    >>> a.get_weights()
1242    [array([[1.],
1243           [1.],
1244           [1.]], dtype=float32), array([0.], dtype=float32)]
1245    >>> b = tf.keras.layers.Dense(1,
1246    ...   kernel_initializer=tf.constant_initializer(2.))
1247    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
1248    >>> b.get_weights()
1249    [array([[2.],
1250           [2.],
1251           [2.]], dtype=float32), array([0.], dtype=float32)]
1252    >>> b.set_weights(a.get_weights())
1253    >>> b.get_weights()
1254    [array([[1.],
1255           [1.],
1256           [1.]], dtype=float32), array([0.], dtype=float32)]
1257
1258    Args:
1259        weights: a list of Numpy arrays. The number
1260            of arrays and their shape must match
1261            number of the dimensions of the weights
1262            of the layer (i.e. it should match the
1263            output of `get_weights`).
1264
1265    Raises:
1266        ValueError: If the provided weights list does not match the
1267            layer's specifications.
1268    """
1269    params = self.weights
1270
1271    expected_num_weights = 0
1272    for param in params:
1273      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1274        expected_num_weights += param.num_tensors
1275      else:
1276        expected_num_weights += 1
1277
1278    if expected_num_weights != len(weights):
1279      raise ValueError(
1280          'You called `set_weights(weights)` on layer "%s" '
1281          'with a weight list of length %s, but the layer was '
1282          'expecting %s weights. Provided weights: %s...' %
1283          (self.name, len(weights), expected_num_weights, str(weights)[:50]))
1284
1285    weight_index = 0
1286    weight_value_tuples = []
1287    for param in params:
1288      if isinstance(param, base_layer_utils.TrackableWeightHandler):
1289        num_tensors = param.num_tensors
1290        tensors = weights[weight_index:weight_index + num_tensors]
1291        param.set_weights(tensors)
1292        weight_index += num_tensors
1293      else:
1294        weight = weights[weight_index]
1295        ref_shape = param.shape
1296        if not ref_shape.is_compatible_with(weight.shape):
1297          raise ValueError(
1298              'Layer weight shape %s not compatible with provided weight '
1299              'shape %s' % (ref_shape, weight.shape))
1300        weight_value_tuples.append((param, weight))
1301        weight_index += 1
1302
1303    backend.batch_set_value(weight_value_tuples)
1304
1305  def get_weights(self):
1306    """Returns the current weights of the layer.
1307
1308    The weights of a layer represent the state of the layer. This function
1309    returns both trainable and non-trainable weight values associated with this
1310    layer as a list of Numpy arrays, which can in turn be used to load state
1311    into similarly parameterized layers.
1312
1313    For example, a Dense layer returns a list of two values-- per-output
1314    weights and the bias value. These can be used to set the weights of another
1315    Dense layer:
1316
1317    >>> a = tf.keras.layers.Dense(1,
1318    ...   kernel_initializer=tf.constant_initializer(1.))
1319    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
1320    >>> a.get_weights()
1321    [array([[1.],
1322           [1.],
1323           [1.]], dtype=float32), array([0.], dtype=float32)]
1324    >>> b = tf.keras.layers.Dense(1,
1325    ...   kernel_initializer=tf.constant_initializer(2.))
1326    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
1327    >>> b.get_weights()
1328    [array([[2.],
1329           [2.],
1330           [2.]], dtype=float32), array([0.], dtype=float32)]
1331    >>> b.set_weights(a.get_weights())
1332    >>> b.get_weights()
1333    [array([[1.],
1334           [1.],
1335           [1.]], dtype=float32), array([0.], dtype=float32)]
1336
1337    Returns:
1338        Weights values as a list of numpy arrays.
1339    """
1340    weights = self.weights
1341    output_weights = []
1342    for weight in weights:
1343      if isinstance(weight, base_layer_utils.TrackableWeightHandler):
1344        output_weights.extend(weight.get_tensors())
1345      else:
1346        output_weights.append(weight)
1347    return backend.batch_get_value(output_weights)
1348
1349  def get_updates_for(self, inputs):
1350    """Retrieves updates relevant to a specific set of inputs.
1351
1352    Args:
1353      inputs: Input tensor or list/tuple of input tensors.
1354
1355    Returns:
1356      List of update ops of the layer that depend on `inputs`.
1357    """
1358    if inputs is None:
1359      # Requesting unconditional updates.
1360      return [u for u in self.updates if u._unconditional_update]
1361
1362    # Requesting input-conditional updates.
1363    updates = [u for u in self.updates if not u._unconditional_update]
1364    inputs = nest.flatten(inputs)
1365    reachable = tf_utils.get_reachable_from_inputs(inputs, updates)
1366    return [u for u in updates if u in reachable]
1367
1368  def get_losses_for(self, inputs):
1369    """Retrieves losses relevant to a specific set of inputs.
1370
1371    Args:
1372      inputs: Input tensor or list/tuple of input tensors.
1373
1374    Returns:
1375      List of loss tensors of the layer that depend on `inputs`.
1376    """
1377    if inputs is None:
1378      # Requesting unconditional losses.
1379      return [l for l in self.losses if l._unconditional_loss]
1380
1381    # Requesting input-conditional losses.
1382    losses = [l for l in self.losses if not l._unconditional_loss]
1383    inputs = nest.flatten(inputs)
1384    reachable = tf_utils.get_reachable_from_inputs(inputs, losses)
1385    return [l for l in losses if l in reachable]
1386
1387  def get_input_mask_at(self, node_index):
1388    """Retrieves the input mask tensor(s) of a layer at a given node.
1389
1390    Args:
1391        node_index: Integer, index of the node
1392            from which to retrieve the attribute.
1393            E.g. `node_index=0` will correspond to the
1394            first time the layer was called.
1395
1396    Returns:
1397        A mask tensor
1398        (or list of tensors if the layer has multiple inputs).
1399    """
1400    inputs = self.get_input_at(node_index)
1401    if isinstance(inputs, list):
1402      return [getattr(x, '_keras_mask', None) for x in inputs]
1403    else:
1404      return getattr(inputs, '_keras_mask', None)
1405
1406  def get_output_mask_at(self, node_index):
1407    """Retrieves the output mask tensor(s) of a layer at a given node.
1408
1409    Args:
1410        node_index: Integer, index of the node
1411            from which to retrieve the attribute.
1412            E.g. `node_index=0` will correspond to the
1413            first time the layer was called.
1414
1415    Returns:
1416        A mask tensor
1417        (or list of tensors if the layer has multiple outputs).
1418    """
1419    output = self.get_output_at(node_index)
1420    if isinstance(output, list):
1421      return [getattr(x, '_keras_mask', None) for x in output]
1422    else:
1423      return getattr(output, '_keras_mask', None)
1424
1425  @property
1426  def input_mask(self):
1427    """Retrieves the input mask tensor(s) of a layer.
1428
1429    Only applicable if the layer has exactly one inbound node,
1430    i.e. if it is connected to one incoming layer.
1431
1432    Returns:
1433        Input mask tensor (potentially None) or list of input
1434        mask tensors.
1435
1436    Raises:
1437        AttributeError: if the layer is connected to
1438        more than one incoming layers.
1439    """
1440    inputs = self.input
1441    if isinstance(inputs, list):
1442      return [getattr(x, '_keras_mask', None) for x in inputs]
1443    else:
1444      return getattr(inputs, '_keras_mask', None)
1445
1446  @property
1447  def output_mask(self):
1448    """Retrieves the output mask tensor(s) of a layer.
1449
1450    Only applicable if the layer has exactly one inbound node,
1451    i.e. if it is connected to one incoming layer.
1452
1453    Returns:
1454        Output mask tensor (potentially None) or list of output
1455        mask tensors.
1456
1457    Raises:
1458        AttributeError: if the layer is connected to
1459        more than one incoming layers.
1460    """
1461    output = self.output
1462    if isinstance(output, list):
1463      return [getattr(x, '_keras_mask', None) for x in output]
1464    else:
1465      return getattr(output, '_keras_mask', None)
1466
1467  def get_input_shape_at(self, node_index):
1468    """Retrieves the input shape(s) of a layer at a given node.
1469
1470    Args:
1471        node_index: Integer, index of the node
1472            from which to retrieve the attribute.
1473            E.g. `node_index=0` will correspond to the
1474            first time the layer was called.
1475
1476    Returns:
1477        A shape tuple
1478        (or list of shape tuples if the layer has multiple inputs).
1479
1480    Raises:
1481      RuntimeError: If called in Eager mode.
1482    """
1483    return self._get_node_attribute_at_index(node_index, 'input_shapes',
1484                                             'input shape')
1485
1486  def get_output_shape_at(self, node_index):
1487    """Retrieves the output shape(s) of a layer at a given node.
1488
1489    Args:
1490        node_index: Integer, index of the node
1491            from which to retrieve the attribute.
1492            E.g. `node_index=0` will correspond to the
1493            first time the layer was called.
1494
1495    Returns:
1496        A shape tuple
1497        (or list of shape tuples if the layer has multiple outputs).
1498
1499    Raises:
1500      RuntimeError: If called in Eager mode.
1501    """
1502    return self._get_node_attribute_at_index(node_index, 'output_shapes',
1503                                             'output shape')
1504
1505  def get_input_at(self, node_index):
1506    """Retrieves the input tensor(s) of a layer at a given node.
1507
1508    Args:
1509        node_index: Integer, index of the node
1510            from which to retrieve the attribute.
1511            E.g. `node_index=0` will correspond to the
1512            first input node of the layer.
1513
1514    Returns:
1515        A tensor (or list of tensors if the layer has multiple inputs).
1516
1517    Raises:
1518      RuntimeError: If called in Eager mode.
1519    """
1520    return self._get_node_attribute_at_index(node_index, 'input_tensors',
1521                                             'input')
1522
1523  def get_output_at(self, node_index):
1524    """Retrieves the output tensor(s) of a layer at a given node.
1525
1526    Args:
1527        node_index: Integer, index of the node
1528            from which to retrieve the attribute.
1529            E.g. `node_index=0` will correspond to the
1530            first output node of the layer.
1531
1532    Returns:
1533        A tensor (or list of tensors if the layer has multiple outputs).
1534
1535    Raises:
1536      RuntimeError: If called in Eager mode.
1537    """
1538    return self._get_node_attribute_at_index(node_index, 'output_tensors',
1539                                             'output')
1540
1541  @property
1542  def input(self):
1543    """Retrieves the input tensor(s) of a layer.
1544
1545    Only applicable if the layer has exactly one input,
1546    i.e. if it is connected to one incoming layer.
1547
1548    Returns:
1549        Input tensor or list of input tensors.
1550
1551    Raises:
1552      RuntimeError: If called in Eager mode.
1553      AttributeError: If no inbound nodes are found.
1554    """
1555    if not self._inbound_nodes:
1556      raise AttributeError('Layer ' + self.name +
1557                           ' is not connected, no input to return.')
1558    return self._get_node_attribute_at_index(0, 'input_tensors', 'input')
1559
1560  @property
1561  def output(self):
1562    """Retrieves the output tensor(s) of a layer.
1563
1564    Only applicable if the layer has exactly one output,
1565    i.e. if it is connected to one incoming layer.
1566
1567    Returns:
1568      Output tensor or list of output tensors.
1569
1570    Raises:
1571      AttributeError: if the layer is connected to more than one incoming
1572        layers.
1573      RuntimeError: if called in Eager mode.
1574    """
1575    if not self._inbound_nodes:
1576      raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
1577    return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
1578
1579  @property
1580  def input_shape(self):
1581    """Retrieves the input shape(s) of a layer.
1582
1583    Only applicable if the layer has exactly one input,
1584    i.e. if it is connected to one incoming layer, or if all inputs
1585    have the same shape.
1586
1587    Returns:
1588        Input shape, as an integer shape tuple
1589        (or list of shape tuples, one tuple per input tensor).
1590
1591    Raises:
1592        AttributeError: if the layer has no defined input_shape.
1593        RuntimeError: if called in Eager mode.
1594    """
1595    if not self._inbound_nodes:
1596      raise AttributeError('The layer has never been called '
1597                           'and thus has no defined input shape.')
1598    all_input_shapes = set(
1599        [str(node.input_shapes) for node in self._inbound_nodes])
1600    if len(all_input_shapes) == 1:
1601      return self._inbound_nodes[0].input_shapes
1602    else:
1603      raise AttributeError('The layer "' + str(self.name) +
1604                           ' has multiple inbound nodes, '
1605                           'with different input shapes. Hence '
1606                           'the notion of "input shape" is '
1607                           'ill-defined for the layer. '
1608                           'Use `get_input_shape_at(node_index)` '
1609                           'instead.')
1610
1611  def count_params(self):
1612    """Count the total number of scalars composing the weights.
1613
1614    Returns:
1615        An integer count.
1616
1617    Raises:
1618        ValueError: if the layer isn't yet built
1619          (in which case its weights aren't yet defined).
1620    """
1621    if not self.built:
1622      if getattr(self, '_is_graph_network', False):
1623        with tf_utils.maybe_init_scope(self):
1624          self._maybe_build(self.inputs)
1625      else:
1626        raise ValueError('You tried to call `count_params` on ' + self.name +
1627                         ', but the layer isn\'t built. '
1628                         'You can build it manually via: `' + self.name +
1629                         '.build(batch_input_shape)`.')
1630    return layer_utils.count_params(self.weights)
1631
1632  @property
1633  def output_shape(self):
1634    """Retrieves the output shape(s) of a layer.
1635
1636    Only applicable if the layer has one output,
1637    or if all outputs have the same shape.
1638
1639    Returns:
1640        Output shape, as an integer shape tuple
1641        (or list of shape tuples, one tuple per output tensor).
1642
1643    Raises:
1644        AttributeError: if the layer has no defined output shape.
1645        RuntimeError: if called in Eager mode.
1646    """
1647    if not self._inbound_nodes:
1648      raise AttributeError('The layer has never been called '
1649                           'and thus has no defined output shape.')
1650    all_output_shapes = set(
1651        [str(node.output_shapes) for node in self._inbound_nodes])
1652    if len(all_output_shapes) == 1:
1653      return self._inbound_nodes[0].output_shapes
1654    else:
1655      raise AttributeError('The layer "%s"'
1656                           ' has multiple inbound nodes, '
1657                           'with different output shapes. Hence '
1658                           'the notion of "output shape" is '
1659                           'ill-defined for the layer. '
1660                           'Use `get_output_shape_at(node_index)` '
1661                           'instead.' % self.name)
1662
1663  @property
1664  @doc_controls.do_not_doc_inheritable
1665  def inbound_nodes(self):
1666    """Deprecated, do NOT use! Only for compatibility with external Keras."""
1667    return self._inbound_nodes
1668
1669  @property
1670  @doc_controls.do_not_doc_inheritable
1671  def outbound_nodes(self):
1672    """Deprecated, do NOT use! Only for compatibility with external Keras."""
1673    return self._outbound_nodes
1674
1675  ##############################################################################
1676  # Methods & attributes below are public aliases of other methods.            #
1677  ##############################################################################
1678
1679  @doc_controls.do_not_doc_inheritable
1680  def apply(self, inputs, *args, **kwargs):
1681    """Deprecated, do NOT use!
1682
1683    This is an alias of `self.__call__`.
1684
1685    Args:
1686      inputs: Input tensor(s).
1687      *args: additional positional arguments to be passed to `self.call`.
1688      **kwargs: additional keyword arguments to be passed to `self.call`.
1689
1690    Returns:
1691      Output tensor(s).
1692    """
1693    warnings.warn('`layer.apply` is deprecated and '
1694                  'will be removed in a future version. '
1695                  'Please use `layer.__call__` method instead.')
1696    return self.__call__(inputs, *args, **kwargs)
1697
1698  @doc_controls.do_not_doc_inheritable
1699  def add_variable(self, *args, **kwargs):
1700    """Deprecated, do NOT use! Alias for `add_weight`."""
1701    warnings.warn('`layer.add_variable` is deprecated and '
1702                  'will be removed in a future version. '
1703                  'Please use `layer.add_weight` method instead.')
1704    return self.add_weight(*args, **kwargs)
1705
1706  @property
1707  def variables(self):
1708    """Returns the list of all layer variables/weights.
1709
1710    Alias of `self.weights`.
1711
1712    Returns:
1713      A list of variables.
1714    """
1715    return self.weights
1716
1717  @property
1718  def trainable_variables(self):
1719    return self.trainable_weights
1720
1721  @property
1722  def non_trainable_variables(self):
1723    return self.non_trainable_weights
1724
1725  ##############################################################################
1726  # Methods & attributes below are all private and only used by the framework. #
1727  ##############################################################################
1728
1729  @property
1730  def _inbound_nodes(self):
1731    return self._inbound_nodes_value
1732
1733  @_inbound_nodes.setter
1734  @trackable.no_automatic_dependency_tracking
1735  def _inbound_nodes(self, value):
1736    self._inbound_nodes_value = value
1737
1738  @property
1739  def _outbound_nodes(self):
1740    return self._outbound_nodes_value
1741
1742  @_outbound_nodes.setter
1743  @trackable.no_automatic_dependency_tracking
1744  def _outbound_nodes(self, value):
1745    self._outbound_nodes_value = value
1746
1747  def _set_dtype_policy(self, dtype):
1748    """Sets self._dtype_policy."""
1749    if isinstance(dtype, policy.Policy):
1750      self._dtype_policy = dtype
1751    elif isinstance(dtype, dict):
1752      self._dtype_policy = policy.deserialize(dtype)
1753    elif dtype:
1754      self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
1755    else:
1756      self._dtype_policy = policy.global_policy()
1757    if (self._dtype_policy.name == 'mixed_float16' and
1758        not loss_scale_optimizer.strategy_supports_loss_scaling()):
1759      # Although only loss scaling doesn't support certain strategies, to avoid
1760      # confusion, we disallow the 'mixed_float16' policy with unsupported
1761      # strategies. This is because 'mixed_float16' requires loss scaling for
1762      # numeric stability.
1763      strategy = ds_context.get_strategy()
1764      raise ValueError('Mixed precision is not supported with the '
1765                       'tf.distribute.Strategy: %s. Either stop using mixed '
1766                       'precision by removing the use of the "%s" policy or '
1767                       'use a different Strategy, e.g. a MirroredStrategy.' %
1768                       (strategy.__class__.__name__, self._dtype_policy.name))
1769
1770    # Performance optimization: cache the compute dtype as a Dtype object or
1771    # None, so that str to Dtype conversion doesn't happen in Layer.__call__.
1772    if self._dtype_policy.compute_dtype:
1773      self._compute_dtype_object = dtypes.as_dtype(
1774          self._dtype_policy.compute_dtype)
1775    else:
1776      self._compute_dtype_object = None
1777
1778  # TODO(reedwm): Expose this property?
1779  @property
1780  def _compute_dtype(self):
1781    """The layer's compute dtype.
1782
1783    Unless mixed-precision is used, this is the same as `Layer.dtype`.
1784
1785    If self._autocast is True, layer's will cast floating-point inputs to this.
1786
1787    Returns:
1788      The layer's compute dtype.
1789    """
1790    return self._dtype_policy.compute_dtype
1791
1792  def _maybe_cast_inputs(self, inputs):
1793    """Maybe casts the inputs to the compute dtype.
1794
1795    If self._compute_dtype is floating-point, and self_autocast is True,
1796    floating-point inputs are casted to self._compute_dtype.
1797
1798    Args:
1799      inputs: Input tensor, or structure of input tensors.
1800
1801    Returns:
1802      `inputs`, but tensors may have been casted to self._compute_dtype
1803    """
1804    compute_dtype = self._compute_dtype
1805    if (self._autocast and compute_dtype and
1806        dtypes.as_dtype(compute_dtype).is_floating):
1807      def f(x):
1808        """Cast a single Tensor or TensorSpec to the compute dtype."""
1809        cast_types = (ops.Tensor, sparse_tensor.SparseTensor,
1810                      ragged_tensor.RaggedTensor)
1811        if (isinstance(x, cast_types) and x.dtype.is_floating and
1812            x.dtype.base_dtype.name != compute_dtype):
1813          return math_ops.cast(x, compute_dtype)
1814        elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating:
1815          # Inputs may be TensorSpecs when this function is called from
1816          # model._set_inputs.
1817          return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name)
1818        else:
1819          return x
1820      return nest.map_structure(f, inputs)
1821    else:
1822      return inputs
1823
1824  # _dtype used to be an attribute set in the constructor. We still expose it
1825  # because some clients still use it.
1826  # TODO(reedwm): Deprecate, then remove the _dtype property.
1827  @property
1828  def _dtype(self):
1829    # This is equivalent to returning self.dtype . We do not return self.dtype
1830    # as it would cause infinite recursion in a few subclasses, which override
1831    # "dtype" to return self._dtype.
1832    return self._dtype_policy.variable_dtype
1833
1834  @_dtype.setter
1835  def _dtype(self, value):
1836    value = dtypes.as_dtype(value).name
1837    self._set_dtype_policy(policy.Policy(value))
1838
1839  def _name_scope(self):
1840    return self.name
1841
1842  def _init_set_name(self, name, zero_based=True):
1843    if not name:
1844      self._name = backend.unique_object_name(
1845          generic_utils.to_snake_case(self.__class__.__name__),
1846          zero_based=zero_based)
1847    else:
1848      self._name = name
1849
1850  def _get_existing_metric(self, name=None):
1851    match = [m for m in self._metrics if m.name == name]
1852    if not match:
1853      return
1854    if len(match) > 1:
1855      raise ValueError(
1856          'Please provide different names for the metrics you have added. '
1857          'We found {} metrics with the name: "{}"'.format(len(match), name))
1858    return match[0]
1859
1860  def _symbolic_add_metric(self, value, aggregation=None, name=None):
1861    base_layer_utils.check_graph_consistency(value, method='add_metric')
1862    match = self._get_existing_metric(name)
1863    if aggregation is None:
1864      # Iterate over the metrics and check if the given metric exists already.
1865      # This can happen when a metric instance is created in subclassed model
1866      # layer `__init__` and we have tracked that instance already in
1867      # model.__setattr__.
1868      if match:
1869        result_tensor = value
1870        metric_obj = match
1871      elif hasattr(value, '_metric_obj'):
1872        # We track the instance using the metadata on the result tensor.
1873        result_tensor = value
1874        metric_obj = result_tensor._metric_obj
1875        self._metrics.append(metric_obj)
1876      else:
1877        raise ValueError(
1878            'We do not support adding an aggregated metric result tensor that '
1879            'is not the output of a `tf.keras.metrics.Metric` metric instance. '
1880            'Without having access to the metric instance we cannot reset the '
1881            'state of a metric after every epoch during training. You can '
1882            'create a `tf.keras.metrics.Metric` instance and pass the result '
1883            'here or pass an un-aggregated result with `aggregation` parameter '
1884            'set as `mean`. For example: `self.add_metric(tf.reduce_sum(inputs)'
1885            ', name=\'mean_activation\', aggregation=\'mean\')`')
1886    else:
1887      # If a non-aggregated tensor is given as input (ie. `aggregation` is
1888      # explicitly set to `mean`), we wrap the tensor in `Mean` metric.
1889      if match:
1890        result_tensor = match(value)
1891        metric_obj = match
1892      else:
1893        metric_obj, result_tensor = base_layer_utils.create_mean_metric(
1894            value, name)
1895        self._metrics.append(metric_obj)
1896
1897  def _handle_weight_regularization(self, name, variable, regularizer):
1898    """Create lambdas which compute regularization losses."""
1899
1900    def _loss_for_variable(v):
1901      """Creates a regularization loss `Tensor` for variable `v`."""
1902      with backend.name_scope(name + '/Regularizer'):
1903        regularization = regularizer(v)
1904      return regularization
1905
1906    if base_layer_utils.is_split_variable(variable):
1907      for v in variable:
1908        self.add_loss(functools.partial(_loss_for_variable, v))
1909    else:
1910      self.add_loss(functools.partial(_loss_for_variable, variable))
1911
1912  def _handle_activity_regularization(self, inputs, outputs):
1913    # Apply activity regularization.
1914    # Note that it should be applied every time the layer creates a new
1915    # output, since it is output-specific.
1916    if self._activity_regularizer:
1917      output_list = nest.flatten(outputs)
1918      with backend.name_scope('ActivityRegularizer'):
1919        for output in output_list:
1920          activity_loss = self._activity_regularizer(output)
1921          batch_size = math_ops.cast(
1922              array_ops.shape(output)[0], activity_loss.dtype)
1923          # Make activity regularization strength batch-agnostic.
1924          mean_activity_loss = activity_loss / batch_size
1925          base_layer_utils.check_graph_consistency(
1926              mean_activity_loss, method='activity_regularizer')
1927          self.add_loss(mean_activity_loss, inputs=inputs)
1928
1929  def _set_mask_metadata(self, inputs, outputs, previous_mask):
1930    flat_outputs = nest.flatten(outputs)
1931
1932    mask_already_computed = (
1933        getattr(self, '_compute_output_and_mask_jointly', False) or
1934        all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
1935
1936    # Only compute the mask if the Layer explicitly supports masking or has
1937    # overridden `compute_mask`.
1938    should_compute_mask = (
1939        hasattr(self, 'compute_mask') and
1940        (self.supports_masking or
1941         not getattr(self.compute_mask, '_is_default', False)))
1942
1943    if mask_already_computed:
1944      flat_masks = [getattr(x, '_keras_mask', None) for x in flat_outputs]
1945    elif not should_compute_mask:
1946      flat_masks = [None for _ in flat_outputs]
1947    else:
1948      output_masks = self.compute_mask(inputs, previous_mask)
1949      # `compute_mask` can return a single `None` even when a Layer
1950      # has multiple outputs.
1951      if output_masks is None:
1952        flat_masks = [None for _ in flat_outputs]
1953      else:
1954        flat_masks = nest.flatten(output_masks)
1955
1956    for output, mask in zip(flat_outputs, flat_masks):
1957      try:
1958        output._keras_mask = mask
1959      except AttributeError:
1960        # C Type such as np.ndarray.
1961        pass
1962
1963    if tf_utils.are_all_symbolic_tensors(flat_outputs):
1964      for output in flat_outputs:
1965        if getattr(output, '_keras_mask', None) is not None:
1966          # Do not track masks for `TensorFlowOpLayer` construction.
1967          output._keras_mask._keras_history_checked = True
1968
1969  def _collect_input_masks(self, inputs, args, kwargs):
1970    """Checks if `mask` argument was passed, else gathers mask from inputs."""
1971    if self._call_arg_was_passed('mask', args, kwargs):
1972      return self._get_call_arg_value('mask', args, kwargs)
1973
1974    if not self._should_compute_mask:
1975      return None
1976
1977    input_masks = nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
1978                                     inputs)
1979    if generic_utils.is_all_none(input_masks):
1980      return None
1981    return input_masks
1982
1983  def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
1984    if arg_name in kwargs:
1985      return True
1986    call_fn_args = self._call_fn_args
1987    if not inputs_in_args:
1988      # Ignore `inputs` arg.
1989      call_fn_args = call_fn_args[1:]
1990    if arg_name in dict(zip(call_fn_args, args)):
1991      return True
1992    return False
1993
1994  def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False):
1995    if arg_name in kwargs:
1996      return kwargs[arg_name]
1997    call_fn_args = self._call_fn_args
1998    if not inputs_in_args:
1999      # Ignore `inputs` arg.
2000      call_fn_args = call_fn_args[1:]
2001    args_dict = dict(zip(call_fn_args, args))
2002    return args_dict[arg_name]
2003
2004  def _set_call_arg_value(
2005      self, arg_name, new_value, args,
2006      kwargs, inputs_in_args=False, pop_kwarg_if_none=False):
2007    arg_pos = self._call_fn_arg_positions.get(arg_name, None)
2008    if arg_pos is not None:
2009      if not inputs_in_args:
2010        # Ignore `inputs` arg.
2011        arg_pos = arg_pos - 1
2012      if len(args) > arg_pos:
2013        args = list(args)
2014        args[arg_pos] = new_value
2015        return args, kwargs
2016    if new_value is None and pop_kwarg_if_none:
2017      kwargs.pop(arg_name, None)
2018    else:
2019      kwargs[arg_name] = new_value
2020    return args, kwargs
2021
2022  def _get_node_attribute_at_index(self, node_index, attr, attr_name):
2023    """Private utility to retrieves an attribute (e.g. inputs) from a node.
2024
2025    This is used to implement the methods:
2026        - get_input_shape_at
2027        - get_output_shape_at
2028        - get_input_at
2029        etc...
2030
2031    Args:
2032        node_index: Integer index of the node from which
2033            to retrieve the attribute.
2034        attr: Exact node attribute name.
2035        attr_name: Human-readable attribute name, for error messages.
2036
2037    Returns:
2038        The layer's attribute `attr` at the node of index `node_index`.
2039
2040    Raises:
2041        RuntimeError: If the layer has no inbound nodes, or if called in Eager
2042        mode.
2043        ValueError: If the index provided does not match any node.
2044    """
2045    if not self._inbound_nodes:
2046      raise RuntimeError('The layer has never been called '
2047                         'and thus has no defined ' + attr_name + '.')
2048    if not len(self._inbound_nodes) > node_index:
2049      raise ValueError('Asked to get ' + attr_name + ' at node ' +
2050                       str(node_index) + ', but the layer has only ' +
2051                       str(len(self._inbound_nodes)) + ' inbound nodes.')
2052    values = getattr(self._inbound_nodes[node_index], attr)
2053    if isinstance(values, list) and len(values) == 1:
2054      return values[0]
2055    else:
2056      return values
2057
2058  def _maybe_build(self, inputs):
2059    # Check input assumptions set before layer building, e.g. input rank.
2060    if not self.built:
2061      input_spec.assert_input_compatibility(
2062          self.input_spec, inputs, self.name)
2063      input_list = nest.flatten(inputs)
2064      if input_list and self._dtype_policy.compute_dtype is None:
2065        try:
2066          dtype = input_list[0].dtype.base_dtype.name
2067        except AttributeError:
2068          pass
2069        else:
2070          self._set_dtype_policy(policy.Policy(dtype))
2071      input_shapes = None
2072      if all(hasattr(x, 'shape') for x in input_list):
2073        input_shapes = nest.map_structure(lambda x: x.shape, inputs)
2074      # Only call `build` if the user has manually overridden the build method.
2075      if not hasattr(self.build, '_is_default'):
2076        # Any setup work performed only once should happen in an `init_scope`
2077        # to avoid creating symbolic Tensors that will later pollute any eager
2078        # operations.
2079        with tf_utils.maybe_init_scope(self):
2080          self.build(input_shapes)
2081      # We must set also ensure that the layer is marked as built, and the build
2082      # shape is stored since user defined build functions may not be calling
2083      # `super.build()`
2084      Layer.build(self, input_shapes)
2085
2086    # Optionally load weight values specified at layer instantiation.
2087    if self._initial_weights is not None:
2088      self.set_weights(self._initial_weights)
2089      self._initial_weights = None
2090
2091  def _symbolic_call(self, inputs):
2092    input_shapes = nest.map_structure(lambda x: x.shape, inputs)
2093    output_shapes = self.compute_output_shape(input_shapes)
2094
2095    def _make_placeholder_like(shape):
2096      ph = backend.placeholder(shape=shape, dtype=self.dtype)
2097      ph._keras_mask = None
2098      return ph
2099
2100    return nest.map_structure(_make_placeholder_like, output_shapes)
2101
2102  def _get_trainable_state(self):
2103    """Get the `trainable` state of each sublayer.
2104
2105    Returns:
2106      A dict mapping all sublayers to their `trainable` value.
2107    """
2108    layers = self._flatten_layers(include_self=False, recursive=False)
2109    trainable_state = {self: self.trainable}
2110    for l in layers:
2111      trainable_state.update(l._get_trainable_state())
2112    return trainable_state
2113
2114  def _set_trainable_state(self, trainable_state):
2115    """Set `trainable` state for each sublayer."""
2116    if self in trainable_state:
2117      self.trainable = trainable_state[self]
2118    layers = self._flatten_layers(include_self=False, recursive=False)
2119    for l in layers:
2120      if l in trainable_state:
2121        l._set_trainable_state(trainable_state)
2122
2123  @property
2124  def _obj_reference_counts(self):
2125    """A dictionary counting the number of attributes referencing an object."""
2126    self._maybe_create_attribute('_obj_reference_counts_dict',
2127                                 object_identity.ObjectIdentityDictionary())
2128    return self._obj_reference_counts_dict
2129
2130  @trackable.no_automatic_dependency_tracking
2131  def _maybe_create_attribute(self, name, default_value):
2132    """Create the attribute with the default value if it hasn't been created.
2133
2134    This is useful for fields that is used for tracking purpose,
2135    _trainable_weights, or _layers. Note that user could create a layer subclass
2136    and assign an internal field before invoking the Layer.__init__(), the
2137    __setattr__() need to create the tracking fields and __init__() need to not
2138    override them.
2139
2140    Args:
2141      name: String, the name of the attribute.
2142      default_value: Object, the default value of the attribute.
2143    """
2144    if not hasattr(self, name):
2145      self.__setattr__(name, default_value)
2146
2147  def __delattr__(self, name):
2148    # For any super.__delattr__() call, we will directly use the implementation
2149    # in Trackable and skip the behavior in AutoTrackable. The Layer was
2150    # originally use Trackable as base class, the change of using Module as base
2151    # class forced us to have AutoTrackable in the class hierarchy. Skipping
2152    # the __delattr__ and __setattr__ in AutoTrackable will keep the status quo.
2153    existing_value = getattr(self, name, None)
2154
2155    # If this value is replacing an existing object assigned to an attribute, we
2156    # should clean it out to avoid leaking memory. First we check if there are
2157    # other attributes referencing it.
2158    reference_counts = self._obj_reference_counts
2159    if existing_value not in reference_counts:
2160      super(tracking.AutoTrackable, self).__delattr__(name)
2161      return
2162
2163    reference_count = reference_counts[existing_value]
2164    if reference_count > 1:
2165      # There are other remaining references. We can't remove this object from
2166      # _layers etc.
2167      reference_counts[existing_value] = reference_count - 1
2168      super(tracking.AutoTrackable, self).__delattr__(name)
2169      return
2170    else:
2171      # This is the last remaining reference.
2172      del reference_counts[existing_value]
2173
2174    super(tracking.AutoTrackable, self).__delattr__(name)
2175
2176    if (isinstance(existing_value, Layer)
2177        or base_layer_utils.has_weights(existing_value)):
2178      super(tracking.AutoTrackable, self).__setattr__(
2179          '_self_tracked_trackables',
2180          [l for l in self._self_tracked_trackables if l is not existing_value])
2181    if isinstance(existing_value, tf_variables.Variable):
2182      super(tracking.AutoTrackable, self).__setattr__(
2183          '_trainable_weights',
2184          [w for w in self._trainable_weights if w is not existing_value])
2185      super(tracking.AutoTrackable, self).__setattr__(
2186          '_non_trainable_weights',
2187          [w for w in self._non_trainable_weights if w is not existing_value])
2188
2189  def __setattr__(self, name, value):
2190    if (name == '_self_setattr_tracking' or
2191        not getattr(self, '_self_setattr_tracking', True) or
2192        # Exclude @property.setters from tracking
2193        hasattr(self.__class__, name)):
2194      try:
2195        super(tracking.AutoTrackable, self).__setattr__(name, value)
2196      except AttributeError:
2197        raise AttributeError(
2198            ('Can\'t set the attribute "{}", likely because it conflicts with '
2199             'an existing read-only @property of the object. Please choose a '
2200             'different name.').format(name))
2201      return
2202
2203    # Keep track of trackable objects, for the needs of `Network.save_weights`.
2204    value = data_structures.sticky_attribute_assignment(
2205        trackable=self, value=value, name=name)
2206
2207    reference_counts = self._obj_reference_counts
2208    reference_counts[value] = reference_counts.get(value, 0) + 1
2209
2210    # Clean out the old attribute, which clears _layers and _trainable_weights
2211    # if necessary.
2212    try:
2213      self.__delattr__(name)
2214    except AttributeError:
2215      pass
2216
2217    # Keep track of metric instance created in subclassed layer.
2218    from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
2219    for val in nest.flatten(value):
2220      if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'):
2221        self._metrics.append(val)
2222
2223    # TODO(scottzhu): Need to track Module object as well for weight tracking.
2224    # Be careful about metric if it becomes a Module in future.
2225    # Append value to self._layers if relevant
2226    if (getattr(self, '_auto_track_sub_layers', True) and
2227        (isinstance(value, Layer) or base_layer_utils.has_weights(value))):
2228      self._maybe_create_attribute('_self_tracked_trackables', [])
2229      # We need to check object identity to avoid de-duplicating empty
2230      # container types which compare equal.
2231      if not any((layer is value for layer in self._self_tracked_trackables)):
2232        self._self_tracked_trackables.append(value)
2233        if hasattr(value, '_use_resource_variables'):
2234          # Legacy layers (V1 tf.layers) must always use
2235          # resource variables.
2236          value._use_resource_variables = True
2237
2238    # Append value to list of trainable / non-trainable weights if relevant
2239    # TODO(b/125122625): This won't pick up on any variables added to a
2240    # list/dict after creation.
2241    for val in nest.flatten(value):
2242      if not isinstance(val, tf_variables.Variable):
2243        continue
2244
2245      # Users may add extra weights/variables
2246      # simply by assigning them to attributes (invalid for graph networks)
2247      self._maybe_create_attribute('_trainable_weights', [])
2248      self._maybe_create_attribute('_non_trainable_weights', [])
2249      if val.trainable:
2250        if any(val is w for w in self._trainable_weights):
2251          continue
2252        self._trainable_weights.append(val)
2253      else:
2254        if any(val is w for w in self._non_trainable_weights):
2255          continue
2256        self._non_trainable_weights.append(val)
2257
2258      backend.track_variable(val)
2259
2260    # Skip the auto trackable from tf.Module to keep status quo. See the comment
2261    # at __delattr__.
2262    super(tracking.AutoTrackable, self).__setattr__(name, value)
2263
2264  # This is a hack so that the is_layer (within
2265  # training/trackable/layer_utils.py) check doesn't get the weights attr.
2266  # TODO(b/110718070): Remove when fixed.
2267  def _is_layer(self):
2268    return True
2269
2270  def _init_call_fn_args(self):
2271    # Clear cached call function arguments.
2272    self.__class__._call_full_argspec.fget.cache.pop(self, None)
2273    self.__class__._call_fn_args.fget.cache.pop(self, None)
2274    self.__class__._call_accepts_kwargs.fget.cache.pop(self, None)
2275
2276    call_fn_args = self._call_fn_args
2277    self._expects_training_arg = ('training' in call_fn_args or
2278                                  self._call_accepts_kwargs)
2279    self._expects_mask_arg = ('mask' in call_fn_args or
2280                              self._call_accepts_kwargs)
2281
2282  @property
2283  @layer_utils.cached_per_instance
2284  def _call_full_argspec(self):
2285    # Argspec inspection is expensive and the call spec is used often, so it
2286    # makes sense to cache the result.
2287    return tf_inspect.getfullargspec(self.call)
2288
2289  @property
2290  @layer_utils.cached_per_instance
2291  def _call_fn_args(self):
2292    all_args = self._call_full_argspec.args
2293    # Scrub `self` that appears if a decorator was applied.
2294    if all_args and all_args[0] == 'self':
2295      return all_args[1:]
2296    return all_args
2297
2298  @property
2299  @layer_utils.cached_per_instance
2300  def _call_fn_arg_positions(self):
2301    call_fn_arg_positions = dict()
2302    for pos, arg in enumerate(self._call_fn_args):
2303      call_fn_arg_positions[arg] = pos
2304    return call_fn_arg_positions
2305
2306  @property
2307  @layer_utils.cached_per_instance
2308  def _call_accepts_kwargs(self):
2309    return self._call_full_argspec.varkw is not None
2310
2311  @property
2312  @layer_utils.cached_per_instance
2313  def _should_compute_mask(self):
2314    return ('mask' in self._call_fn_args or
2315            getattr(self, 'compute_mask', None) is not None)
2316
2317  def _dedup_weights(self, weights):
2318    """Dedupe weights while maintaining order as much as possible."""
2319    output, seen_ids = [], set()
2320    for w in weights:
2321      if id(w) not in seen_ids:
2322        output.append(w)
2323        # Track the Variable's identity to avoid __eq__ issues.
2324        seen_ids.add(id(w))
2325
2326    return output
2327
2328  # SavedModel properties. Please see keras/saving/saved_model for details.
2329
2330  @property
2331  def _trackable_saved_model_saver(self):
2332    return layer_serialization.LayerSavedModelSaver(self)
2333
2334  @property
2335  def _object_identifier(self):
2336    return self._trackable_saved_model_saver.object_identifier
2337
2338  @property
2339  def _tracking_metadata(self):
2340    return self._trackable_saved_model_saver.tracking_metadata
2341
2342  def _list_extra_dependencies_for_serialization(self, serialization_cache):
2343    return (self._trackable_saved_model_saver
2344            .list_extra_dependencies_for_serialization(serialization_cache))
2345
2346  def _list_functions_for_serialization(self, serialization_cache):
2347    return (self._trackable_saved_model_saver
2348            .list_functions_for_serialization(serialization_cache))
2349
2350  def __getstate__(self):
2351    # Override to support `copy.deepcopy` and pickling.
2352    # Thread-local objects cannot be copied in Python 3, so pop these.
2353    # Thread-local objects are used to cache losses in MirroredStrategy, and
2354    # so shouldn't be copied.
2355    state = self.__dict__.copy()
2356    state.pop('_thread_local', None)
2357    return state
2358
2359  def __setstate__(self, state):
2360    state['_thread_local'] = threading.local()
2361    # Bypass Trackable logic as `__dict__` already contains this info.
2362    object.__setattr__(self, '__dict__', state)
2363
2364
2365class KerasHistory(
2366    collections.namedtuple('KerasHistory',
2367                           ['layer', 'node_index', 'tensor_index'])):
2368  """Tracks the Layer call that created a Tensor, for Keras Graph Networks.
2369
2370  During construction of Keras Graph Networks, this metadata is added to
2371  each Tensor produced as the output of a Layer, starting with an
2372  `InputLayer`. This allows Keras to track how each Tensor was produced, and
2373  this information is later retraced by the `keras.engine.Network` class to
2374  reconstruct the Keras Graph Network.
2375
2376  Attributes:
2377    layer: The Layer that produced the Tensor.
2378    node_index: The specific call to the Layer that produced this Tensor. Layers
2379      can be called multiple times in order to share weights. A new node is
2380      created every time a Tensor is called.
2381    tensor_index: The output index for this Tensor. Always zero if the Layer
2382      that produced this Tensor only has one output. Nested structures of
2383      Tensors are deterministically assigned an index via `nest.flatten`.
2384  """
2385  # Added to maintain memory and performance characteristics of `namedtuple`
2386  # while subclassing.
2387  __slots__ = ()
2388
2389
2390# Avoid breaking users who directly import this symbol from this file.
2391# TODO(fchollet): remove this.
2392InputSpec = input_spec.InputSpec  # pylint:disable=invalid-name
2393