# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import itertools
import threading
import warnings

import numpy as np
import six
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision import autocast_variable
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.saving.saved_model import layer_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
# Modules that only depend on `keras.layers` import these from here.
from tensorflow.python.keras.utils.generic_utils import to_snake_case  # pylint: disable=unused-import
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list  # pylint: disable=unused-import
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest
from tensorflow.tools.docs import doc_controls


# pylint: disable=g-classes-have-attributes
class Layer(base_layer.Layer):
  """Base layer class.

  This is the class from which all layers inherit.

  A layer is a class implementing common neural network operations, such
  as convolution, batch norm, etc. These operations require managing weights,
  losses, updates, and inter-layer connectivity.

  Users will just instantiate a layer and then treat it as a callable.

  We recommend that descendants of `Layer` implement the following methods
  (see the example below):

  * `__init__()`: Save configuration in member variables.
  * `build()`: Called once from `__call__`, when we know the shapes of inputs
    and `dtype`. Should have the calls to `add_weight()`, and then
    call the super's `build()` (which sets `self.built = True`, which is
    nice in case the user wants to call `build()` manually before the
    first `__call__`).
  * `call()`: Called in `__call__` after making sure `build()` has been called
    once. Should actually perform the logic of applying the layer to the
    input tensors (which should be passed in as the first argument).

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: The dtype of the layer's computations and weights (default of
      `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type
      of the first input in TensorFlow 1).
    dynamic: Set this to `True` if your layer should only be run eagerly, and
      should not be used to generate a static computation graph.
      This would be the case for a Tree-RNN or a recursive network,
      for example, or generally for any layer that manipulates tensors
      using Python control flow. If `False`, we assume that the layer can
      safely be used to generate a static computation graph.

  Attributes:
    name: The name of the layer (string).
    dtype: The dtype of the layer's computations and weights. If mixed
      precision is used with a `tf.keras.mixed_precision.Policy`, this is
      instead just the dtype of the layer's weights, as the computations are
      done in a different dtype.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
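
  For illustration, here is a minimal sketch of a `Layer` subclass following
  the recommendations above (the `SimpleDense` name, unit count, and shapes
  are hypothetical, not part of this API):

  ```python
  class SimpleDense(tf.keras.layers.Layer):

    def __init__(self, units=32, **kwargs):
      super(SimpleDense, self).__init__(**kwargs)
      self.units = units  # Save configuration in member variables.

    def build(self, input_shape):
      # Create weights once the input shape is known.
      self.kernel = self.add_weight(
          name='kernel',
          shape=(int(input_shape[-1]), self.units),
          initializer='glorot_uniform',
          trainable=True)
      self.bias = self.add_weight(
          name='bias', shape=(self.units,), initializer='zeros')
      super(SimpleDense, self).build(input_shape)  # Sets self.built = True.

    def call(self, inputs):
      # Apply the layer's logic to the (already cast) input tensors.
      return tf.matmul(inputs, self.kernel) + self.bias
  ```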

  Each layer has a dtype, which is typically the dtype of the layer's
  computations and variables. A layer's dtype can be queried via the
  `Layer.dtype` property. The dtype is specified with the `dtype` constructor
  argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()`
  if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
  layers will cast their inputs to the layer's dtype in TensorFlow 2. When
  mixed precision is used, layers may have different computation and variable
  dtypes. See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
  """

  # See tf.Module for the usage of this property.
  # The key for _obj_reference_counts_dict is a Trackable, which could be a
  # variable or layer etc. tf.Module._flatten will fail to flatten the key
  # since it is trying to convert Trackable to a string. This attribute can be
  # ignored even after the fix of the nest lib, since the trackable object
  # should already be available as individual attributes.
  # _obj_reference_counts_dict just contains a copy of them.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_obj_reference_counts_dict',),
      module.Module._TF_MODULE_IGNORED_PROPERTIES
  ))

  @trackable.no_automatic_dependency_tracking
  def __init__(self, trainable=True, name=None, dtype=None, dynamic=False,
               **kwargs):
    self._instrument_layer_creation()

    # These properties should be set by the user via keyword arguments.
    # Note that 'dtype', 'input_shape' and 'batch_input_shape'
    # are only applicable to input layers: do not pass these keywords
    # to non-input layers.
    allowed_kwargs = {
        'input_dim', 'input_shape', 'batch_input_shape', 'batch_size',
        'weights', 'activity_regularizer', 'autocast', 'implementation'
    }
    # Validate optional keyword arguments.
    generic_utils.validate_kwargs(kwargs, allowed_kwargs)

    # Mutable properties
    # Indicates whether the layer's weights are updated during training
    # and whether the layer's updates are run during training.
    self._trainable = trainable
    # A stateful layer is a layer whose updates are run during inference too,
    # for instance stateful RNNs.
    self._stateful = False
    # Indicates whether `build` needs to be called upon layer call, to create
    # the layer's weights.
    self.built = False
    self._build_input_shape = None
    # Provides information about which inputs are compatible with the layer.
    self._input_spec = None
    self.supports_masking = False

    self._init_set_name(name)
    self._activity_regularizer = regularizers.get(
        kwargs.pop('activity_regularizer', None))
    self._maybe_create_attribute('_trainable_weights', [])
    self._maybe_create_attribute('_non_trainable_weights', [])
    self._updates = []
    # Object to store all thread local layer properties.
    self._thread_local = threading.local()
    # A list of zero-argument lambdas which return Tensors, used for variable
    # regularizers.
    self._callable_losses = []
    # A list of symbolic Tensors containing activity regularizers and losses
    # manually added through `add_loss` in graph-building mode.
    self._losses = []
    # A list of metric instances corresponding to the symbolic metric tensors
    # added using the `add_metric` API.
    self._metrics = []

    # Both graph and subclassed networks have a dtype policy. For graph
    # networks, the policy's compute and variable dtypes are ignored. Such
    # networks only use the policy if it is a PolicyV1, in which case it uses
    # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
    # subclassed networks, the compute and variable dtypes are used as in any
    # ordinary layer.
    self._set_dtype_policy(dtype)
    # Boolean indicating whether the layer automatically casts its inputs to
    # the layer's compute_dtype.
    self._autocast = kwargs.get('autocast',
                                base_layer_utils.v2_dtype_behavior_enabled())

    # Dependencies tracked via attribute assignment.
    # All layers in order of horizontal graph traversal.
    # Entries are unique. For models, includes input and output layers.
    self._maybe_create_attribute('_self_tracked_trackables', [])

    # These lists will be filled via successive calls
    # to self._add_inbound_node().
    # Used in symbolic mode only, only in conjunction with graph-networks.
    self._inbound_nodes_value = []
    self._outbound_nodes_value = []

    self._init_call_fn_args()

    # Whether the `call` method can be used to build a TF graph without issues.
    # This attribute has no effect if the model is created using the Functional
    # API. Instead, `model.dynamic` is determined based on the internal layers.
    self._dynamic = dynamic

    # Manage input shape information if passed.
    if 'input_dim' in kwargs and 'input_shape' not in kwargs:
      # Backwards compatibility: alias 'input_dim' to 'input_shape'.
      kwargs['input_shape'] = (kwargs['input_dim'],)
    if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
      # In this case we will later create an input layer
      # to insert before the current layer.
      if 'batch_input_shape' in kwargs:
        batch_input_shape = tuple(kwargs['batch_input_shape'])
      elif 'input_shape' in kwargs:
        if 'batch_size' in kwargs:
          batch_size = kwargs['batch_size']
        else:
          batch_size = None
        batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
      self._batch_input_shape = batch_input_shape

    # Manage initial weight values if passed.
    self._initial_weights = kwargs.get('weights', None)

    # Whether the layer should track any layers that are set as attributes on
    # itself as sub-layers; the weights from such sub-layers will be included
    # in the parent layer's variables() as well.
    # Defaults to True, which means auto tracking is turned on. Certain
    # subclasses may want to turn it off, like the Sequential model.
    self._auto_track_sub_layers = True

    # Mark this layer as having been originally built as a tf1 layer/model.
    self._originally_built_as_v1 = True

    # For backwards compat reasons, most built-in layers do not guarantee
    # that they will 100% preserve the structure of input args when saving
    # / loading configs. E.g. they may un-nest an arg that is
    # a list with one element.
    self._preserve_input_structure_in_config = False

  @trackable.no_automatic_dependency_tracking
  @generic_utils.default
  def build(self, input_shape):
    """Creates the variables of the layer (optional, for subclass implementers).

    This is a method that implementers of subclasses of `Layer` or `Model`
    can override if they need a state-creation step in-between
    layer instantiation and layer call.

    This is typically used to create the weights of `Layer` subclasses.

    Args:
      input_shape: Instance of `TensorShape`, or list of instances of
        `TensorShape` if the layer expects a list of inputs
        (one instance per input).
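
    Example (a minimal sketch; the `BiasAdd` layer below is illustrative
    only, not part of this API):

    ```python
    class BiasAdd(tf.keras.layers.Layer):

      def build(self, input_shape):
        # `input_shape` is a single `TensorShape` here; it would be a list of
        # `TensorShape`s if the layer were called on a list of inputs.
        self.bias = self.add_weight(
            name='bias', shape=(int(input_shape[-1]),), initializer='zeros')
        super(BiasAdd, self).build(input_shape)  # Marks the layer as built.

      def call(self, inputs):
        return inputs + self.bias
    ```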
    """
    if not hasattr(self.build, '_is_default'):
      self._build_input_shape = input_shape
    self.built = True

  @doc_controls.for_subclass_implementers
  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """This is where the layer's logic lives.

    Args:
      inputs: Input tensor, or list/tuple of input tensors.
      **kwargs: Additional keyword arguments.

    Returns:
      A tensor or list/tuple of tensors.
    """
    return inputs

  @doc_controls.for_subclass_implementers
  def _add_trackable(self, trackable_object, trainable):
    """Adds a Trackable object to this layer's state.

    Args:
      trackable_object: The tf.tracking.Trackable object to add.
      trainable: Boolean, whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases) or
        "non_trainable_variables" (e.g. BatchNorm mean and variance).

    Returns:
      The TrackableWeightHandler used to track this object.
    """
    handler = base_layer_utils.TrackableWeightHandler(trackable_object)
    if trainable:
      self._trainable_weights.append(handler)
    else:
      self._non_trainable_weights.append(handler)
    return handler

  @doc_controls.for_subclass_implementers
  def add_weight(self,
                 name=None,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=None,
                 constraint=None,
                 partitioner=None,
                 use_resource=None,
                 synchronization=tf_variables.VariableSynchronization.AUTO,
                 aggregation=tf_variables.VariableAggregation.NONE,
                 **kwargs):
    """Adds a new variable to the layer.

    Args:
      name: Variable name.
      shape: Variable shape. Defaults to scalar if unspecified.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: Initializer instance (callable).
      regularizer: Regularizer instance (callable).
      trainable: Boolean, whether the variable should be part of the layer's
        "trainable_variables" (e.g. variables, biases)
        or "non_trainable_variables" (e.g. BatchNorm mean and variance).
        Note that `trainable` cannot be `True` if `synchronization`
        is set to `ON_READ`.
      constraint: Constraint instance (callable).
      partitioner: Partitioner to be passed to the `Trackable` API.
      use_resource: Whether to use a `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize. If `synchronization` is set to `ON_READ`,
        `trainable` must not be set to `True`.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      **kwargs: Additional keyword arguments. Accepted values are `getter`,
        `collections`, `experimental_autocast` and `caching_device`.

    Returns:
      The created variable. Usually either a `Variable` or `ResourceVariable`
      instance. If `partitioner` is not `None`, a `PartitionedVariable`
      instance is returned.

    Raises:
      RuntimeError: If called with partitioned variable regularization and
        eager execution is enabled.
      ValueError: When giving an unsupported dtype and no initializer, or when
        `trainable` has been set to `True` with `synchronization` set to
        `ON_READ`.
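
    Example (a minimal sketch; the layer, names and shapes below are
    hypothetical, shown only to illustrate typical `add_weight` usage):

    ```python
    class ScaledEmbedding(tf.keras.layers.Layer):

      def build(self, input_shape):
        # A trainable weight with a regularizer attached. The regularization
        # loss will show up in `self.losses`.
        self.embeddings = self.add_weight(
            name='embeddings',
            shape=(1000, 64),
            initializer='glorot_uniform',
            regularizer=tf.keras.regularizers.l2(1e-4),
            trainable=True)
        # A non-trainable scalar counter, excluded from backprop.
        self.calls_seen = self.add_weight(
            name='calls_seen', shape=(), dtype=tf.int64,
            initializer='zeros', trainable=False)
        super(ScaledEmbedding, self).build(input_shape)
    ```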
    """
    if shape is None:
      shape = ()
    # Validate optional keyword arguments.
    for kwarg in kwargs:
      if kwarg not in ['getter', 'collections', 'experimental_autocast',
                       'caching_device']:
        raise TypeError('Unknown keyword argument:', kwarg)
    getter = kwargs.pop('getter', base_layer_utils.make_variable)
    collections_arg = kwargs.pop('collections', None)
    # 'experimental_autocast' can be set to False by the caller to indicate an
    # AutoCastVariable should never be created.
    autocast = kwargs.pop('experimental_autocast', True)
    # See the docstring for tf.Variable about the details for caching_device.
    caching_device = kwargs.pop('caching_device', None)

    if dtype is None:
      dtype = self.dtype or backend.floatx()
    dtype = dtypes.as_dtype(dtype)
    if self._dtype_policy.variable_dtype is None:
      # The policy is "_infer", so we infer the policy from the variable dtype.
      self._set_dtype_policy(policy.Policy(dtype.base_dtype.name))
    initializer = initializers.get(initializer)
    regularizer = regularizers.get(regularizer)
    constraint = constraints.get(constraint)

    if synchronization == tf_variables.VariableSynchronization.ON_READ:
      if trainable:
        raise ValueError(
            'Synchronization value can be set to '
            'VariableSynchronization.ON_READ only for non-trainable variables. '
            'You have specified trainable=True and '
            'synchronization=VariableSynchronization.ON_READ.')
      else:
        # Set trainable to be false when the variable is to be synced on read.
        trainable = False
    elif trainable is None:
      trainable = True

    # Initialize the variable when no initializer is provided.
    if initializer is None:
      # If dtype is DT_FLOAT, provide a uniform unit scaling initializer.
      if dtype.is_floating:
        initializer = initializers.get('glorot_uniform')
      # If dtype is DT_INT/DT_UINT, provide a default value `zero`.
      # If dtype is DT_BOOL, provide a default value `FALSE`.
      elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
        initializer = initializers.zeros()
      # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
      else:
        raise ValueError('An initializer for variable %s of type %s is required'
                         ' for layer %s' % (name, dtype.base_dtype, self.name))

    if (autocast and
        self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype
        and dtype.is_floating):
      # Wrap 'getter' with a version that returns an AutoCastVariable.
      old_getter = getter
      def getter(*args, **kwargs):  # pylint: disable=function-redefined
        variable = old_getter(*args, **kwargs)
        return autocast_variable.create_autocast_variable(variable)
      # Also, the caching_device does not work with the mixed precision API,
      # so disable it if it is specified.
      # TODO(b/142020079): Reenable it once the bug is fixed.
      if caching_device is not None:
        tf_logging.warn('`caching_device` does not work with mixed precision '
                        'API. Ignoring user specified `caching_device`.')
        caching_device = None

    variable = self._add_variable_with_custom_getter(
        name=name,
        shape=shape,
        # TODO(allenl): a `make_variable` equivalent should be added as a
        # `Trackable` method.
        getter=getter,
        # Manage errors in Layer rather than Trackable.
        overwrite=True,
        initializer=initializer,
        dtype=dtype,
        constraint=constraint,
        trainable=trainable,
        partitioner=partitioner,
        use_resource=use_resource,
        collections=collections_arg,
        synchronization=synchronization,
        aggregation=aggregation,
        caching_device=caching_device)
    if regularizer is not None:
      # TODO(fchollet): in the future, this should be handled at the
      # level of variable creation, and weight regularization losses
      # should be variable attributes.
      name_in_scope = variable.name[:variable.name.find(':')]
      self._handle_weight_regularization(name_in_scope,
                                         variable,
                                         regularizer)
    if base_layer_utils.is_split_variable(variable):
      for v in variable:
        backend.track_variable(v)
        if trainable:
          self._trainable_weights.append(v)
        else:
          self._non_trainable_weights.append(v)
    else:
      backend.track_variable(variable)
      if trainable:
        self._trainable_weights.append(variable)
      else:
        self._non_trainable_weights.append(variable)
    return variable

  @generic_utils.default
  def get_config(self):
    """Returns the config of the layer.

    A layer config is a Python dictionary (serializable)
    containing the configuration of a layer.
    The same layer can be reinstantiated later
    (without its trained weights) from this configuration.

    The config of a layer does not include connectivity
    information, nor the layer class name. These are handled
    by `Network` (one layer of abstraction above).

    Returns:
      Python dictionary.
    """
    all_args = tf_inspect.getfullargspec(self.__init__).args
    config = {'name': self.name, 'trainable': self.trainable}
    if hasattr(self, '_batch_input_shape'):
      config['batch_input_shape'] = self._batch_input_shape
    config['dtype'] = policy.serialize(self._dtype_policy)
    if hasattr(self, 'dynamic'):
      # Only include `dynamic` in the `config` if it is `True`.
      if self.dynamic:
        config['dynamic'] = self.dynamic
      elif 'dynamic' in all_args:
        all_args.remove('dynamic')
    expected_args = config.keys()
    # Finds all arguments in the `__init__` that are not in the config:
    extra_args = [arg for arg in all_args if arg not in expected_args]
    # Check that either the only argument in the `__init__` is `self`,
    # or that `get_config` has been overridden:
    if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
      raise NotImplementedError('Layers with arguments in `__init__` must '
                                'override `get_config`.')
    return config

  @classmethod
  def from_config(cls, config):
    """Creates a layer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same layer from the config
    dictionary. It does not handle layer connectivity
    (handled by Network), nor weights (handled by `set_weights`).

    Args:
      config: A Python dictionary, typically the
        output of `get_config`.

    Returns:
      A layer instance.
    """
    return cls(**config)

  def compute_output_shape(self, input_shape):
    """Computes the output shape of the layer.

    If the layer has not been built, this method will call `build` on the
    layer. This assumes that the layer will later be used with inputs that
    match the input shape provided here.

    Args:
      input_shape: Shape tuple (tuple of integers)
        or list of shape tuples (one per output tensor of the layer).
        Shape tuples can include None for free dimensions,
        instead of an integer.

    Returns:
      An output shape tuple.
    """
    if context.executing_eagerly():
      # In this case we build the model first in order to do shape inference.
      # This is acceptable because the framework only calls
      # `compute_output_shape` on shape values that the layer would later be
      # built for. It would however cause issues in case a user attempts to
      # use `compute_output_shape` manually with shapes that are incompatible
      # with the shape the Layer will be called on (these users will have to
      # implement `compute_output_shape` themselves).
      self._maybe_build(input_shape)
      with ops.get_default_graph().as_default():
        graph = func_graph.FuncGraph('graph')
        with graph.as_default():
          input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)
          inputs = nest.map_structure(
              base_layer_utils.generate_placeholders_from_shape, input_shape)
          try:
            outputs = self(inputs, training=False)
          except TypeError as e:
            six.raise_from(
                NotImplementedError(
                    'We could not automatically infer the static shape of the '
                    'layer\'s output. Please implement the '
                    '`compute_output_shape` method on your layer (%s).' %
                    self.__class__.__name__), e)
      return nest.map_structure(lambda t: t.shape, outputs)
    raise NotImplementedError

  @doc_controls.for_subclass_implementers
  def compute_output_signature(self, input_signature):
    """Compute the output tensor signature of the layer based on the inputs.

    Unlike a TensorShape object, a TensorSpec object contains both shape
    and dtype information for a tensor. This method allows layers to provide
    output dtype information if it is different from the input dtype.
    For any layer that doesn't implement this function,
    the framework will fall back to use `compute_output_shape`, and will
    assume that the output dtype matches the input dtype.

    Args:
      input_signature: Single TensorSpec or nested structure of TensorSpec
        objects, describing a candidate input for the layer.

    Returns:
      Single TensorSpec or nested structure of TensorSpec objects, describing
      how the layer would transform the provided input.

    Raises:
      TypeError: If input_signature contains a non-TensorSpec object.
    """
    def check_type_return_shape(s):
      if not isinstance(s, tensor_spec.TensorSpec):
        raise TypeError('Only TensorSpec signature types are supported, '
                        'but saw signature entry: {}.'.format(s))
      return s.shape
    input_shape = nest.map_structure(check_type_return_shape, input_signature)
    output_shape = self.compute_output_shape(input_shape)
    dtype = self._compute_dtype
    if dtype is None:
      input_dtypes = [s.dtype for s in nest.flatten(input_signature)]
      # Default behavior when self.dtype is None is to use the first input's
      # dtype.
      dtype = input_dtypes[0]
    return nest.map_structure(
        lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s),
        output_shape)

  @generic_utils.default
  def compute_mask(self, inputs, mask=None):  # pylint: disable=unused-argument
    """Computes an output mask tensor.

    Args:
      inputs: Tensor or list of tensors.
      mask: Tensor or list of tensors.

    Returns:
      None or a tensor (or list of tensors,
      one per output tensor of the layer).
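
    Example (a minimal sketch; the layer and its masking rule are
    hypothetical, shown only to illustrate overriding this method):

    ```python
    class ZeroMasking(tf.keras.layers.Layer):

      def __init__(self, **kwargs):
        super(ZeroMasking, self).__init__(**kwargs)
        # Declare that this layer produces/propagates masks.
        self.supports_masking = True

      def call(self, inputs):
        return inputs

      def compute_mask(self, inputs, mask=None):
        # Mask out timesteps whose feature vectors are entirely zero.
        return tf.reduce_any(tf.not_equal(inputs, 0.), axis=-1)
    ```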
    """
    if not self.supports_masking:
      if any(m is not None for m in nest.flatten(mask)):
        raise TypeError('Layer ' + self.name + ' does not support masking, '
                        'but was passed an input_mask: ' + str(mask))
      # Masking not explicitly supported: return None as mask.
      return None
    # If masking is explicitly supported, by default
    # carry over the input mask.
    return mask

  def __call__(self, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Args:
      *args: Positional arguments to be passed to `self.call`.
      **kwargs: Keyword arguments to be passed to `self.call`.

    Returns:
      Output tensor(s).

    Note:
      - The following optional keyword arguments are reserved for specific
        uses:
        * `training`: Boolean scalar tensor or Python boolean indicating
          whether the `call` is meant for training or inference.
        * `mask`: Boolean input mask.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid
        value).
      RuntimeError: if `super().__init__()` was not called in the constructor.
    """
    self._assert_built_as_v1()

    if not hasattr(self, '_thread_local'):
      raise RuntimeError(
          'You must call `super().__init__()` in the layer constructor.')

    # Grab the first positional or keyword argument.
    if args:
      inputs = args[0]
      args = args[1:]
    elif self._call_fn_args[0] in kwargs:
      inputs = kwargs.pop(self._call_fn_args[0])
    else:
      raise ValueError(
          'The first argument to `Layer.call` must always be passed.')

    call_context = base_layer_utils.call_context()
    input_list = nest.flatten(inputs)

    # We will attempt to build a TF graph if & only if all inputs are
    # symbolic. This is always the case in graph mode. It can also be the
    # case in eager mode when all inputs can be traced back to `keras.Input()`
    # (when building models using the functional API).
    build_graph = tf_utils.are_all_symbolic_tensors(input_list)

    # Accept NumPy and scalar inputs by converting to Tensors.
    if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
      def _convert_non_tensor(x):
        # Don't call `ops.convert_to_tensor` on all `inputs` because
        # `SparseTensors` can't be converted to `Tensor`.
        if isinstance(x, (np.ndarray, float, int)):
          return ops.convert_to_tensor_v2_with_dispatch(x)
        return x
      inputs = nest.map_structure(_convert_non_tensor, inputs)
      input_list = nest.flatten(inputs)

    # Handle `mask` propagation from previous layer to current layer. Masks
    # can be propagated explicitly via the `mask` argument, or implicitly via
    # setting the `_keras_mask` attribute on the inputs to a Layer. Masks
    # passed explicitly take priority.
    mask_arg_passed_by_framework = False
    input_masks = self._collect_input_masks(inputs, args, kwargs)
    if (self._expects_mask_arg and input_masks is not None and
        not self._call_arg_was_passed('mask', args, kwargs)):
      mask_arg_passed_by_framework = True
      kwargs['mask'] = input_masks

    # If the `training` argument is None or not explicitly passed,
    # propagate the `training` value from this layer's calling layer.
    training_value = None
    training_arg_passed_by_framework = False
    # Priority 1: `training` was explicitly passed.
    if self._call_arg_was_passed('training', args, kwargs):
      training_value = self._get_call_arg_value('training', args, kwargs)
      if not self._expects_training_arg:
        kwargs.pop('training')

    if training_value is None:
      # Priority 2: `training` was passed to a parent layer.
      if call_context.training is not None:
        training_value = call_context.training
      # Priority 3a: `learning_phase()` has been set.
      elif backend.global_learning_phase_is_set():
        training_value = backend.learning_phase()
      # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
      elif build_graph:
        with backend.get_graph().as_default():
          if base_layer_utils.is_in_keras_graph():
            training_value = backend.learning_phase()

    if self._expects_training_arg and training_value is not None:
      # Force the training_value to be bool type, which matches the contract
      # for layer/model call args.
      if tensor_util.is_tf_type(training_value):
        training_value = math_ops.cast(training_value, dtypes.bool)
      else:
        training_value = bool(training_value)
      args, kwargs = self._set_call_arg_value(
          'training', training_value, args, kwargs)
      training_arg_passed_by_framework = True

    # Only create Keras history if at least one tensor originates from a
    # `keras.Input`. Otherwise this Layer may be being used outside the Keras
    # framework.
    if build_graph and base_layer_utils.needs_keras_history(inputs):
      base_layer_utils.create_keras_history(inputs)

    with call_context.enter(self, inputs, build_graph, training_value):
      # Check input assumptions set after layer building, e.g. input shape.
      if build_graph:
        # Symbolic execution on symbolic tensors. We will attempt to build
        # the corresponding TF subgraph inside `backend.get_graph()`.
        # TODO(reedwm): We should assert input compatibility after the inputs
        # are cast, not before.
        input_spec.assert_input_compatibility(self.input_spec, inputs,
                                              self.name)
        graph = backend.get_graph()
        with graph.as_default(), backend.name_scope(self._name_scope()):
          # Build layer if applicable (if the `build` method has been
          # overridden).
          self._maybe_build(inputs)
          cast_inputs = self._maybe_cast_inputs(inputs)

          # Wrapping `call` function in autograph to allow for dynamic control
          # flow and control dependencies in call. We are limiting this to
          # subclassed layers as autograph is strictly needed only for
          # subclassed layers and models.
          # tf_convert will respect the value of the autograph setting in the
          # enclosing tf.function, if any.
          if (base_layer_utils.is_subclassed(self) and
              not base_layer_utils.from_saved_model(self)):
            call_fn = autograph.tf_convert(
                self.call, ag_ctx.control_status_ctx())
          else:
            call_fn = self.call

          if not self.dynamic:
            try:
              with autocast_variable.enable_auto_cast_variables(
                  self._compute_dtype_object):
                outputs = call_fn(cast_inputs, *args, **kwargs)

            except errors.OperatorNotAllowedInGraphError as e:
              raise TypeError('You are attempting to use Python control '
                              'flow in a layer that was not declared to be '
                              'dynamic. Pass `dynamic=True` to the class '
                              'constructor.\nEncountered error:\n"""\n' +
                              str(e) + '\n"""')
          else:
            # We will use static shape inference to return symbolic tensors
            # matching the specifications of the layer outputs.
            # Since `self.dynamic` is True, we will never attempt to
            # run the underlying TF graph (which is disconnected).
            # TODO(fchollet): consider py_func as an alternative, which
            # would enable us to run the underlying graph if needed.
            outputs = self._symbolic_call(inputs)

          if outputs is None:
            raise ValueError('A layer\'s `call` method should return a '
                             'Tensor or a list of Tensors, not None '
                             '(layer: ' + self.name + ').')
          if base_layer_utils.have_all_keras_metadata(inputs):
            if training_arg_passed_by_framework:
              args, kwargs = self._set_call_arg_value(
                  'training', None, args, kwargs, pop_kwarg_if_none=True)
            if mask_arg_passed_by_framework:
              kwargs.pop('mask')
            outputs = self._set_connectivity_metadata((inputs,) + args, kwargs,
                                                      outputs)
          self._handle_activity_regularization(inputs, outputs)
          self._set_mask_metadata(inputs, outputs, input_masks)
          if hasattr(self, '_set_inputs') and not self.inputs:
            # Subclassed network: explicitly set metadata normally set by
            # a call to self._set_inputs().
            # TODO(b/120997007): This should be done in Eager as well, but
            # causes garbage collection issues because of the placeholders
            # created on the default Keras graph.
            self._set_inputs(inputs, outputs)
      else:
        # Eager execution on data tensors.
        with backend.name_scope(self._name_scope()):
          self._maybe_build(inputs)
          cast_inputs = self._maybe_cast_inputs(inputs)
          with autocast_variable.enable_auto_cast_variables(
              self._compute_dtype_object):
            outputs = self.call(cast_inputs, *args, **kwargs)
          self._handle_activity_regularization(inputs, outputs)
          self._set_mask_metadata(inputs, outputs, input_masks)

    return outputs

  def _assert_built_as_v1(self):
    if not hasattr(self, '_originally_built_as_v1'):
      raise ValueError(
          'Your Layer or Model is in an invalid state. '
          'This can happen for the following cases:\n '
          '1. You might be interleaving estimator/non-estimator models or '
          'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
          'with models/layers created outside of it. '
          'Converting a model to an estimator (via model_to_estimator) '
          'invalidates all models/layers made before the conversion (even '
          'if they were not the model converted to an estimator). '
          'Similarly, making a layer or a model inside '
          'a tf.compat.v1.Graph invalidates all layers/models you previously '
          'made outside of the graph.\n'
          '2. You might be using a custom keras layer implementation with '
          'a custom __init__ which didn\'t call super().__init__. '
          'Please check the implementation of %s and its bases.' %
          (type(self),))

  @property
  def dtype(self):
    return self._dtype_policy.variable_dtype

  @property
  def name(self):
    return self._name

  @property
  def dynamic(self):
    return any(layer._dynamic for layer in self._flatten_layers())

  @property
  @doc_controls.do_not_generate_docs
  def stateful(self):
    return any(layer._stateful for layer in self._flatten_layers())

  @stateful.setter
  def stateful(self, value):
    self._stateful = value

  @property
  def trainable(self):
    return self._trainable

  @trainable.setter
  def trainable(self, value):
    self._trainable = value
    for layer in getattr(self, '_self_tracked_trackables', []):
      layer.trainable = value

  @property
  def activity_regularizer(self):
    """Optional regularizer function for the output of this layer."""
    return self._activity_regularizer

  @activity_regularizer.setter
  def activity_regularizer(self, regularizer):
    """Optional regularizer function for the output of this layer."""
    self._activity_regularizer = regularizer

  @property
  def input_spec(self):
    return self._input_spec

  @input_spec.setter
  # Must be decorated to prevent tracking, since the input_spec can be nested
  # InputSpec objects.
  @trackable.no_automatic_dependency_tracking
  def input_spec(self, value):
    for v in nest.flatten(value):
      if v is not None and not isinstance(v, base_layer.InputSpec):
        raise TypeError('Layer input_spec must be an instance of InputSpec. '
                        'Got: {}'.format(v))
    self._input_spec = value

  @property
  def updates(self):
    collected_updates = []
    all_layers = self._flatten_layers()
    with backend.get_graph().as_default():
      for layer in all_layers:
        if not layer.trainable and not layer.stateful:
          continue
        for u in layer._updates:
          if callable(u):
            try:
              u = u()
            except ValueError as e:
              if 'InaccessibleTensorError' in type(e).__name__:
                # For one specific case of error we try to raise
                # a more meaningful error message about the graph if we can.
                # This error is an internal TF symbol that is not
                # publicly exposed, so we check the name directly rather
                # than using a direct import.
                base_layer_utils.check_graph_consistency(
                    method='add_update', force_raise=True)
              raise  # check_graph_consistency may not always raise.
          base_layer_utils.check_graph_consistency(u, method='add_update')
          collected_updates.append(u)
    return collected_updates

  @property
  def losses(self):
    """Losses which are associated with this `Layer`.

    Variable regularization tensors are created when this property is
    accessed, so it is eager safe: accessing `losses` under a
    `tf.GradientTape` will propagate gradients back to the corresponding
    variables.

    Returns:
      A list of tensors.
    """
    collected_losses = []
    all_layers = self._flatten_layers()
    for layer in all_layers:
      # If any eager losses are present, we assume the model to be part of an
      # eager training loop (either a custom one or the one used when
      # `run_eagerly=True`) and so we always return just the eager losses.
      collected_losses.extend(layer._losses)
      for regularizer in layer._callable_losses:
        loss_tensor = regularizer()
        if loss_tensor is not None:
          collected_losses.append(loss_tensor)
    return collected_losses

  @doc_controls.for_subclass_implementers
  def add_loss(self, losses, inputs=None):
    """Add loss tensor(s), potentially dependent on layer inputs.

    Some losses (for instance, activity regularization losses) may be
    dependent on the inputs passed when calling a layer. Hence, when reusing
    the same layer on different inputs `a` and `b`, some entries in
    `layer.losses` may be dependent on `a` and some on `b`. This method
    automatically keeps track of dependencies.

    This method can be used inside a subclassed layer or model's `call`
    function, in which case `losses` should be a Tensor or list of Tensors.

    Example:

    ```python
    class MyLayer(tf.keras.layers.Layer):
      def call(self, inputs):
        self.add_loss(tf.abs(tf.reduce_mean(inputs)), inputs=True)
        return inputs
    ```

    This method can also be called directly on a Functional Model during
    construction. In this case, any loss Tensors passed to this Model must
    be symbolic and be able to be traced back to the model's `Input`s. These
    losses become part of the model's topology and are tracked in
    `get_config`.

    Example:

    ```python
    inputs = tf.keras.Input(shape=(10,))
    x = tf.keras.layers.Dense(10)(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    # Activity regularization.
    model.add_loss(tf.abs(tf.reduce_mean(x)))
    ```

    If this is not the case for your loss (if, for example, your loss
    references a `Variable` of one of the model's layers), you can wrap your
    loss in a zero-argument lambda. These losses are not tracked as part of
    the model's topology since they can't be serialized.

    Example:

    ```python
    inputs = tf.keras.Input(shape=(10,))
    d = tf.keras.layers.Dense(10)
    x = d(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    # Weight regularization.
    model.add_loss(lambda: tf.reduce_mean(d.kernel))
    ```

    The `get_losses_for` method allows you to retrieve the losses relevant to
    a specific set of inputs.

    Args:
      losses: Loss tensor, or list/tuple of tensors. Rather than tensors,
        losses may also be zero-argument callables which create a loss
        tensor.
      inputs: Ignored when executing eagerly. If anything other than None is
        passed, it signals the losses are conditional on some of the layer's
        inputs, and thus they should only be run where these inputs are
        available. This is the case for activity regularization losses, for
        instance. If `None` is passed, the losses are assumed
        to be unconditional, and will apply across all dataflows of the layer
        (e.g. weight regularization losses).
    """
    def _tag_unconditional(loss):
      """Process the loss and tag it by setting loss._unconditional_loss."""
      if callable(loss):
        # We run the loss without autocasting, as regularizers are often
        # numerically unstable in float16.
        with autocast_variable.enable_auto_cast_variables(None):
          loss = loss()
      if loss is None:
        return None  # Will be filtered out when computing the .losses property
      if not tensor_util.is_tf_type(loss):
        loss = ops.convert_to_tensor_v2_with_dispatch(
            loss, dtype=backend.floatx())
      loss._unconditional_loss = (inputs is None)  # pylint: disable=protected-access
      return loss

    losses = nest.flatten(losses)

    callable_losses = []
    symbolic_losses = []
    for loss in losses:
      if callable(loss):
        callable_losses.append(functools.partial(_tag_unconditional, loss))
        continue
      if loss is None:
        continue
      if not tensor_util.is_tf_type(loss):
        loss = ops.convert_to_tensor_v2_with_dispatch(
            loss, dtype=backend.floatx())
      # TF Functions should take the eager path.
      if (tf_utils.is_symbolic_tensor(loss) and
          not base_layer_utils.is_in_tf_function()):
        symbolic_losses.append(_tag_unconditional(loss))
        base_layer_utils.check_graph_consistency(loss, method='add_loss')

    self._callable_losses.extend(callable_losses)

    in_call_context = base_layer_utils.call_context().in_call

    if in_call_context:
      for symbolic_loss in symbolic_losses:
        self._losses.append(symbolic_loss)
    else:
      for symbolic_loss in symbolic_losses:
        if getattr(self, '_is_graph_network', False):
          self._graph_network_add_loss(symbolic_loss)
        else:
          # Possibly a loss was added in a Layer's `build`.
          self._losses.append(symbolic_loss)

  @property
  def metrics(self):
    collected_metrics = []
    for layer in self._flatten_layers():
      collected_metrics.extend(layer._metrics)
    return collected_metrics

  @doc_controls.for_subclass_implementers
  def add_metric(self, value, aggregation=None, name=None):
    """Adds metric tensor to the layer.

    Args:
      value: Metric tensor.
      aggregation: Sample-wise metric reduction function. If
        `aggregation=None`, it indicates that the metric tensor provided has
        already been aggregated. E.g., `bin_acc = BinaryAccuracy(name='acc')`
        followed by `model.add_metric(bin_acc(y_true, y_pred))`. If
        `aggregation='mean'`, the given metric tensor will be sample-wise
        reduced using the `mean` function. E.g.,
        `model.add_metric(tf.reduce_sum(outputs), name='output_mean',
        aggregation='mean')`.
      name: String metric name.

    Raises:
      ValueError: If `aggregation` is anything other than None or `mean`.
    """
    if aggregation is not None and aggregation != 'mean':
      raise ValueError(
          'We currently support only `mean` sample-wise metric aggregation. '
          'You provided aggregation=`%s`' % aggregation)

    from_metric_obj = hasattr(value, '_metric_obj')
    is_symbolic = tf_utils.is_symbolic_tensor(value)
    in_call_context = base_layer_utils.call_context().in_call

    if name is None and not from_metric_obj:
      # E.g. `self.add_metric(math_ops.reduce_sum(x), aggregation='mean')`
      # In eager mode, we use the metric name to look up a metric. Without a
      # name, a new Mean metric wrapper will be created on every model/layer
      # call. So, we raise an error when no name is provided.
      # We will do the same for symbolic mode for consistency, although a
      # name will be generated if none is provided.

      # We will not raise this error in the following use case for the sake
      # of consistency, as the name is provided in the metric constructor.
      # mean = metrics.Mean(name='my_metric')
      # model.add_metric(mean(outputs))
      raise ValueError('Please provide a name for your metric like '
                       '`self.add_metric(tf.reduce_sum(inputs), '
                       'name=\'mean_activation\', aggregation=\'mean\')`')
    elif from_metric_obj:
      name = value._metric_obj.name

    if in_call_context:
      # TF Function path should take the eager path.
      self._symbolic_add_metric(value, aggregation, name)
    else:
      if not is_symbolic:
        raise ValueError('Expected a symbolic Tensor for the metric value, '
                         'received: ' + str(value))

      # Possibly a metric was added in a Layer's `build`.
      if not getattr(self, '_is_graph_network', False):
        with backend.get_graph().as_default():
          self._symbolic_add_metric(value, aggregation, name)
        return

      if from_metric_obj:
        raise ValueError('Using the result of calling a `Metric` object '
                         'when calling `add_metric` on a Functional '
                         'Model is not supported. Please pass the '
                         'Tensor to monitor directly.')

      # Insert layers into the Keras Graph Network.
      self._graph_network_add_metric(value, aggregation, name)

  @doc_controls.for_subclass_implementers
  def add_update(self, updates, inputs=None):
    """Add update op(s), potentially dependent on layer inputs.

    Weight updates (for instance, the updates of the moving mean and variance
    in a BatchNormalization layer) may be dependent on the inputs passed
    when calling a layer. Hence, when reusing the same layer on
    different inputs `a` and `b`, some entries in `layer.updates` may be
    dependent on `a` and some on `b`. This method automatically keeps track
    of dependencies.

    The `get_updates_for` method allows you to retrieve the updates relevant
    to a specific set of inputs.

    This call is ignored when eager execution is enabled (in that case,
    variable updates are run on the fly and thus do not need to be tracked
    for later execution).

    Args:
      updates: Update op, or list/tuple of update ops, or zero-arg callable
        that returns an update op. A zero-arg callable should be passed in
        order to disable running the updates by setting `trainable=False`
        on this Layer, when executing in Eager mode.
      inputs: Deprecated, will be automatically inferred.
    """
    if inputs is not None:
      tf_logging.warning(
          '`add_update` `inputs` kwarg has been deprecated. You no longer need '
          'to pass a value to `inputs` as it is being automatically inferred.')
    call_context = base_layer_utils.call_context()

    if (ds_context.has_strategy() and
        ds_context.in_cross_replica_context() and
        # When saving the model, the distribution strategy context should be
        # ignored, following the default path for adding updates.
        not call_context.saving):
      # Updates don't need to be run in a cross-replica context.
      return

    updates = generic_utils.to_list(updates)

    if call_context.in_call:
      relevant_inputs = call_context.inputs
    else:
      inbound_nodes = getattr(self, '_inbound_nodes', [])
      relevant_inputs = [node.input_tensors for node in inbound_nodes]

    def process_update(x):
      """Standardize update ops.

      Args:
        x: Tensor, op, or callable.

      Returns:
        An update op.
      """
      if callable(x):
        update = lambda: process_update(x())
        return update()
      elif isinstance(x, ops.Operation):
        update = x
      elif hasattr(x, 'op'):
        update = x.op
      else:
        update = ops.convert_to_tensor_v2_with_dispatch(x)

      reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, [update])
      update._unconditional_update = update not in reachable
      return update

    updates = [process_update(x) for x in updates]
    self._updates.extend(updates)

  def set_weights(self, weights):
    """Sets the weights of the layer, from Numpy arrays.

    The weights of a layer represent the state of the layer. This function
    sets the weight values from numpy arrays. The weight values should be
    passed in the order they are created by the layer. Note that the layer's
    weights must be instantiated before calling this function, by calling
    the layer.

    For example, a Dense layer returns a list of two values: per-output
    weights and the bias value. These can be used to set the weights of
    another Dense layer:

    >>> a = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(1.))
    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
    >>> a.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]
    >>> b = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(2.))
    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
    >>> b.get_weights()
    [array([[2.],
           [2.],
           [2.]], dtype=float32), array([0.], dtype=float32)]
    >>> b.set_weights(a.get_weights())
    >>> b.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]

    Args:
      weights: a list of Numpy arrays. The number
        of arrays and their shapes must match the
        number and shapes of the weights of the
        layer (i.e. it should match the
        output of `get_weights`).

    Raises:
      ValueError: If the provided weights list does not match the
        layer's specifications.
    """
    params = self.weights

    expected_num_weights = 0
    for param in params:
      if isinstance(param, base_layer_utils.TrackableWeightHandler):
        expected_num_weights += param.num_tensors
      else:
        expected_num_weights += 1

    if expected_num_weights != len(weights):
      raise ValueError(
          'You called `set_weights(weights)` on layer "%s" '
          'with a weight list of length %s, but the layer was '
          'expecting %s weights. Provided weights: %s...' %
          (self.name, len(weights), expected_num_weights, str(weights)[:50]))

    weight_index = 0
    weight_value_tuples = []
    for param in params:
      if isinstance(param, base_layer_utils.TrackableWeightHandler):
        num_tensors = param.num_tensors
        tensors = weights[weight_index:weight_index + num_tensors]
        param.set_weights(tensors)
        weight_index += num_tensors
      else:
        weight = weights[weight_index]
        ref_shape = param.shape
        if not ref_shape.is_compatible_with(weight.shape):
          raise ValueError(
              'Layer weight shape %s not compatible with provided weight '
              'shape %s' % (ref_shape, weight.shape))
        weight_value_tuples.append((param, weight))
        weight_index += 1

    backend.batch_set_value(weight_value_tuples)

  def get_weights(self):
    """Returns the current weights of the layer.

    The weights of a layer represent the state of the layer. This function
    returns both trainable and non-trainable weight values associated with
    this layer as a list of Numpy arrays, which can in turn be used to load
    state into similarly parameterized layers.

    For example, a Dense layer returns a list of two values: per-output
    weights and the bias value. These can be used to set the weights of
    another Dense layer:

    >>> a = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(1.))
    >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]]))
    >>> a.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]
    >>> b = tf.keras.layers.Dense(1,
    ...   kernel_initializer=tf.constant_initializer(2.))
    >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]]))
    >>> b.get_weights()
    [array([[2.],
           [2.],
           [2.]], dtype=float32), array([0.], dtype=float32)]
    >>> b.set_weights(a.get_weights())
    >>> b.get_weights()
    [array([[1.],
           [1.],
           [1.]], dtype=float32), array([0.], dtype=float32)]

    Returns:
      Weights values as a list of numpy arrays.
    """
    weights = self.weights
    output_weights = []
    for weight in weights:
      if isinstance(weight, base_layer_utils.TrackableWeightHandler):
        output_weights.extend(weight.get_tensors())
      else:
        output_weights.append(weight)
    return backend.batch_get_value(output_weights)

  def get_updates_for(self, inputs):
    """Retrieves updates relevant to a specific set of inputs.

    Args:
      inputs: Input tensor or list/tuple of input tensors.

    Returns:
      List of update ops of the layer that depend on `inputs`.
    """
    if inputs is None:
      # Requesting unconditional updates.
      return [u for u in self.updates if u._unconditional_update]

    # Requesting input-conditional updates.
    updates = [u for u in self.updates if not u._unconditional_update]
    inputs = nest.flatten(inputs)
    reachable = tf_utils.get_reachable_from_inputs(inputs, updates)
    return [u for u in updates if u in reachable]

  def get_losses_for(self, inputs):
    """Retrieves losses relevant to a specific set of inputs.

    Args:
      inputs: Input tensor or list/tuple of input tensors.

    Returns:
      List of loss tensors of the layer that depend on `inputs`.
    """
    if inputs is None:
      # Requesting unconditional losses.
      return [l for l in self.losses if l._unconditional_loss]

    # Requesting input-conditional losses.
    losses = [l for l in self.losses if not l._unconditional_loss]
    inputs = nest.flatten(inputs)
    reachable = tf_utils.get_reachable_from_inputs(inputs, losses)
    return [l for l in losses if l in reachable]

  def get_input_mask_at(self, node_index):
    """Retrieves the input mask tensor(s) of a layer at a given node.

    Args:
      node_index: Integer, index of the node
        from which to retrieve the attribute.
        E.g. `node_index=0` will correspond to the
        first time the layer was called.

    Returns:
      A mask tensor
      (or list of tensors if the layer has multiple inputs).
1399 """ 1400 inputs = self.get_input_at(node_index) 1401 if isinstance(inputs, list): 1402 return [getattr(x, '_keras_mask', None) for x in inputs] 1403 else: 1404 return getattr(inputs, '_keras_mask', None) 1405 1406 def get_output_mask_at(self, node_index): 1407 """Retrieves the output mask tensor(s) of a layer at a given node. 1408 1409 Args: 1410 node_index: Integer, index of the node 1411 from which to retrieve the attribute. 1412 E.g. `node_index=0` will correspond to the 1413 first time the layer was called. 1414 1415 Returns: 1416 A mask tensor 1417 (or list of tensors if the layer has multiple outputs). 1418 """ 1419 output = self.get_output_at(node_index) 1420 if isinstance(output, list): 1421 return [getattr(x, '_keras_mask', None) for x in output] 1422 else: 1423 return getattr(output, '_keras_mask', None) 1424 1425 @property 1426 def input_mask(self): 1427 """Retrieves the input mask tensor(s) of a layer. 1428 1429 Only applicable if the layer has exactly one inbound node, 1430 i.e. if it is connected to one incoming layer. 1431 1432 Returns: 1433 Input mask tensor (potentially None) or list of input 1434 mask tensors. 1435 1436 Raises: 1437 AttributeError: if the layer is connected to 1438 more than one incoming layers. 1439 """ 1440 inputs = self.input 1441 if isinstance(inputs, list): 1442 return [getattr(x, '_keras_mask', None) for x in inputs] 1443 else: 1444 return getattr(inputs, '_keras_mask', None) 1445 1446 @property 1447 def output_mask(self): 1448 """Retrieves the output mask tensor(s) of a layer. 1449 1450 Only applicable if the layer has exactly one inbound node, 1451 i.e. if it is connected to one incoming layer. 1452 1453 Returns: 1454 Output mask tensor (potentially None) or list of output 1455 mask tensors. 1456 1457 Raises: 1458 AttributeError: if the layer is connected to 1459 more than one incoming layers. 1460 """ 1461 output = self.output 1462 if isinstance(output, list): 1463 return [getattr(x, '_keras_mask', None) for x in output] 1464 else: 1465 return getattr(output, '_keras_mask', None) 1466 1467 def get_input_shape_at(self, node_index): 1468 """Retrieves the input shape(s) of a layer at a given node. 1469 1470 Args: 1471 node_index: Integer, index of the node 1472 from which to retrieve the attribute. 1473 E.g. `node_index=0` will correspond to the 1474 first time the layer was called. 1475 1476 Returns: 1477 A shape tuple 1478 (or list of shape tuples if the layer has multiple inputs). 1479 1480 Raises: 1481 RuntimeError: If called in Eager mode. 1482 """ 1483 return self._get_node_attribute_at_index(node_index, 'input_shapes', 1484 'input shape') 1485 1486 def get_output_shape_at(self, node_index): 1487 """Retrieves the output shape(s) of a layer at a given node. 1488 1489 Args: 1490 node_index: Integer, index of the node 1491 from which to retrieve the attribute. 1492 E.g. `node_index=0` will correspond to the 1493 first time the layer was called. 1494 1495 Returns: 1496 A shape tuple 1497 (or list of shape tuples if the layer has multiple outputs). 1498 1499 Raises: 1500 RuntimeError: If called in Eager mode. 1501 """ 1502 return self._get_node_attribute_at_index(node_index, 'output_shapes', 1503 'output shape') 1504 1505 def get_input_at(self, node_index): 1506 """Retrieves the input tensor(s) of a layer at a given node. 1507 1508 Args: 1509 node_index: Integer, index of the node 1510 from which to retrieve the attribute. 1511 E.g. `node_index=0` will correspond to the 1512 first input node of the layer. 
1513 1514 Returns: 1515 A tensor (or list of tensors if the layer has multiple inputs). 1516 1517 Raises: 1518 RuntimeError: If called in Eager mode. 1519 """ 1520 return self._get_node_attribute_at_index(node_index, 'input_tensors', 1521 'input') 1522 1523 def get_output_at(self, node_index): 1524 """Retrieves the output tensor(s) of a layer at a given node. 1525 1526 Args: 1527 node_index: Integer, index of the node 1528 from which to retrieve the attribute. 1529 E.g. `node_index=0` will correspond to the 1530 first output node of the layer. 1531 1532 Returns: 1533 A tensor (or list of tensors if the layer has multiple outputs). 1534 1535 Raises: 1536 RuntimeError: If called in Eager mode. 1537 """ 1538 return self._get_node_attribute_at_index(node_index, 'output_tensors', 1539 'output') 1540 1541 @property 1542 def input(self): 1543 """Retrieves the input tensor(s) of a layer. 1544 1545 Only applicable if the layer has exactly one input, 1546 i.e. if it is connected to one incoming layer. 1547 1548 Returns: 1549 Input tensor or list of input tensors. 1550 1551 Raises: 1552 RuntimeError: If called in Eager mode. 1553 AttributeError: If no inbound nodes are found. 1554 """ 1555 if not self._inbound_nodes: 1556 raise AttributeError('Layer ' + self.name + 1557 ' is not connected, no input to return.') 1558 return self._get_node_attribute_at_index(0, 'input_tensors', 'input') 1559 1560 @property 1561 def output(self): 1562 """Retrieves the output tensor(s) of a layer. 1563 1564 Only applicable if the layer has exactly one output, 1565 i.e. if it is connected to one incoming layer. 1566 1567 Returns: 1568 Output tensor or list of output tensors. 1569 1570 Raises: 1571 AttributeError: if the layer is connected to more than one incoming 1572 layers. 1573 RuntimeError: if called in Eager mode. 1574 """ 1575 if not self._inbound_nodes: 1576 raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') 1577 return self._get_node_attribute_at_index(0, 'output_tensors', 'output') 1578 1579 @property 1580 def input_shape(self): 1581 """Retrieves the input shape(s) of a layer. 1582 1583 Only applicable if the layer has exactly one input, 1584 i.e. if it is connected to one incoming layer, or if all inputs 1585 have the same shape. 1586 1587 Returns: 1588 Input shape, as an integer shape tuple 1589 (or list of shape tuples, one tuple per input tensor). 1590 1591 Raises: 1592 AttributeError: if the layer has no defined input_shape. 1593 RuntimeError: if called in Eager mode. 1594 """ 1595 if not self._inbound_nodes: 1596 raise AttributeError('The layer has never been called ' 1597 'and thus has no defined input shape.') 1598 all_input_shapes = set( 1599 [str(node.input_shapes) for node in self._inbound_nodes]) 1600 if len(all_input_shapes) == 1: 1601 return self._inbound_nodes[0].input_shapes 1602 else: 1603 raise AttributeError('The layer "' + str(self.name) + 1604 ' has multiple inbound nodes, ' 1605 'with different input shapes. Hence ' 1606 'the notion of "input shape" is ' 1607 'ill-defined for the layer. ' 1608 'Use `get_input_shape_at(node_index)` ' 1609 'instead.') 1610 1611 def count_params(self): 1612 """Count the total number of scalars composing the weights. 1613 1614 Returns: 1615 An integer count. 1616 1617 Raises: 1618 ValueError: if the layer isn't yet built 1619 (in which case its weights aren't yet defined). 
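
    For example (a minimal illustration; the layer is built on a 4-feature
    input, giving a `4 * 5` kernel plus a bias of size 5):

    >>> layer = tf.keras.layers.Dense(5)
    >>> _ = layer(tf.zeros([1, 4]))
    >>> layer.count_params()
    25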
1620 """ 1621 if not self.built: 1622 if getattr(self, '_is_graph_network', False): 1623 with tf_utils.maybe_init_scope(self): 1624 self._maybe_build(self.inputs) 1625 else: 1626 raise ValueError('You tried to call `count_params` on ' + self.name + 1627 ', but the layer isn\'t built. ' 1628 'You can build it manually via: `' + self.name + 1629 '.build(batch_input_shape)`.') 1630 return layer_utils.count_params(self.weights) 1631 1632 @property 1633 def output_shape(self): 1634 """Retrieves the output shape(s) of a layer. 1635 1636 Only applicable if the layer has one output, 1637 or if all outputs have the same shape. 1638 1639 Returns: 1640 Output shape, as an integer shape tuple 1641 (or list of shape tuples, one tuple per output tensor). 1642 1643 Raises: 1644 AttributeError: if the layer has no defined output shape. 1645 RuntimeError: if called in Eager mode. 1646 """ 1647 if not self._inbound_nodes: 1648 raise AttributeError('The layer has never been called ' 1649 'and thus has no defined output shape.') 1650 all_output_shapes = set( 1651 [str(node.output_shapes) for node in self._inbound_nodes]) 1652 if len(all_output_shapes) == 1: 1653 return self._inbound_nodes[0].output_shapes 1654 else: 1655 raise AttributeError('The layer "%s"' 1656 ' has multiple inbound nodes, ' 1657 'with different output shapes. Hence ' 1658 'the notion of "output shape" is ' 1659 'ill-defined for the layer. ' 1660 'Use `get_output_shape_at(node_index)` ' 1661 'instead.' % self.name) 1662 1663 @property 1664 @doc_controls.do_not_doc_inheritable 1665 def inbound_nodes(self): 1666 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 1667 return self._inbound_nodes 1668 1669 @property 1670 @doc_controls.do_not_doc_inheritable 1671 def outbound_nodes(self): 1672 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 1673 return self._outbound_nodes 1674 1675 ############################################################################## 1676 # Methods & attributes below are public aliases of other methods. # 1677 ############################################################################## 1678 1679 @doc_controls.do_not_doc_inheritable 1680 def apply(self, inputs, *args, **kwargs): 1681 """Deprecated, do NOT use! 1682 1683 This is an alias of `self.__call__`. 1684 1685 Args: 1686 inputs: Input tensor(s). 1687 *args: additional positional arguments to be passed to `self.call`. 1688 **kwargs: additional keyword arguments to be passed to `self.call`. 1689 1690 Returns: 1691 Output tensor(s). 1692 """ 1693 warnings.warn('`layer.apply` is deprecated and ' 1694 'will be removed in a future version. ' 1695 'Please use `layer.__call__` method instead.') 1696 return self.__call__(inputs, *args, **kwargs) 1697 1698 @doc_controls.do_not_doc_inheritable 1699 def add_variable(self, *args, **kwargs): 1700 """Deprecated, do NOT use! Alias for `add_weight`.""" 1701 warnings.warn('`layer.add_variable` is deprecated and ' 1702 'will be removed in a future version. ' 1703 'Please use `layer.add_weight` method instead.') 1704 return self.add_weight(*args, **kwargs) 1705 1706 @property 1707 def variables(self): 1708 """Returns the list of all layer variables/weights. 1709 1710 Alias of `self.weights`. 1711 1712 Returns: 1713 A list of variables. 
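
    For example (an illustrative sketch; `variables` simply aliases `weights`):

    >>> layer = tf.keras.layers.Dense(2)
    >>> _ = layer(tf.zeros([1, 3]))
    >>> len(layer.variables)  # The kernel and the bias.
    2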
    """
    return self.weights

  @property
  def trainable_variables(self):
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    return self.non_trainable_weights

  ##############################################################################
  # Methods & attributes below are all private and only used by the framework. #
  ##############################################################################

  @property
  def _inbound_nodes(self):
    return self._inbound_nodes_value

  @_inbound_nodes.setter
  @trackable.no_automatic_dependency_tracking
  def _inbound_nodes(self, value):
    self._inbound_nodes_value = value

  @property
  def _outbound_nodes(self):
    return self._outbound_nodes_value

  @_outbound_nodes.setter
  @trackable.no_automatic_dependency_tracking
  def _outbound_nodes(self, value):
    self._outbound_nodes_value = value

  def _set_dtype_policy(self, dtype):
    """Sets self._dtype_policy."""
    if isinstance(dtype, policy.Policy):
      self._dtype_policy = dtype
    elif isinstance(dtype, dict):
      self._dtype_policy = policy.deserialize(dtype)
    elif dtype:
      self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
    else:
      self._dtype_policy = policy.global_policy()
    if (self._dtype_policy.name == 'mixed_float16' and
        not loss_scale_optimizer.strategy_supports_loss_scaling()):
      # Only loss scaling is unsupported on these strategies, but to avoid
      # confusion we disallow the 'mixed_float16' policy with them altogether,
      # because 'mixed_float16' requires loss scaling for numeric stability.
      strategy = ds_context.get_strategy()
      raise ValueError('Mixed precision is not supported with the '
                       'tf.distribute.Strategy: %s. Either stop using mixed '
                       'precision by removing the use of the "%s" policy or '
                       'use a different Strategy, e.g. a MirroredStrategy.' %
                       (strategy.__class__.__name__, self._dtype_policy.name))

    # Performance optimization: cache the compute dtype as a Dtype object or
    # None, so that str to Dtype conversion doesn't happen in Layer.__call__.
    if self._dtype_policy.compute_dtype:
      self._compute_dtype_object = dtypes.as_dtype(
          self._dtype_policy.compute_dtype)
    else:
      self._compute_dtype_object = None

  # TODO(reedwm): Expose this property?
  @property
  def _compute_dtype(self):
    """The layer's compute dtype.

    Unless mixed precision is used, this is the same as `Layer.dtype`.

    If `self._autocast` is `True`, the layer will cast floating-point inputs
    to this dtype.

    Returns:
      The layer's compute dtype.
    """
    return self._dtype_policy.compute_dtype

  def _maybe_cast_inputs(self, inputs):
    """Maybe casts the inputs to the compute dtype.

    If `self._compute_dtype` is floating-point and `self._autocast` is `True`,
    floating-point inputs are cast to `self._compute_dtype`.

    Args:
      inputs: Input tensor, or structure of input tensors.
1800 1801 Returns: 1802 `inputs`, but tensors may have been casted to self._compute_dtype 1803 """ 1804 compute_dtype = self._compute_dtype 1805 if (self._autocast and compute_dtype and 1806 dtypes.as_dtype(compute_dtype).is_floating): 1807 def f(x): 1808 """Cast a single Tensor or TensorSpec to the compute dtype.""" 1809 cast_types = (ops.Tensor, sparse_tensor.SparseTensor, 1810 ragged_tensor.RaggedTensor) 1811 if (isinstance(x, cast_types) and x.dtype.is_floating and 1812 x.dtype.base_dtype.name != compute_dtype): 1813 return math_ops.cast(x, compute_dtype) 1814 elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating: 1815 # Inputs may be TensorSpecs when this function is called from 1816 # model._set_inputs. 1817 return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name) 1818 else: 1819 return x 1820 return nest.map_structure(f, inputs) 1821 else: 1822 return inputs 1823 1824 # _dtype used to be an attribute set in the constructor. We still expose it 1825 # because some clients still use it. 1826 # TODO(reedwm): Deprecate, then remove the _dtype property. 1827 @property 1828 def _dtype(self): 1829 # This is equivalent to returning self.dtype . We do not return self.dtype 1830 # as it would cause infinite recursion in a few subclasses, which override 1831 # "dtype" to return self._dtype. 1832 return self._dtype_policy.variable_dtype 1833 1834 @_dtype.setter 1835 def _dtype(self, value): 1836 value = dtypes.as_dtype(value).name 1837 self._set_dtype_policy(policy.Policy(value)) 1838 1839 def _name_scope(self): 1840 return self.name 1841 1842 def _init_set_name(self, name, zero_based=True): 1843 if not name: 1844 self._name = backend.unique_object_name( 1845 generic_utils.to_snake_case(self.__class__.__name__), 1846 zero_based=zero_based) 1847 else: 1848 self._name = name 1849 1850 def _get_existing_metric(self, name=None): 1851 match = [m for m in self._metrics if m.name == name] 1852 if not match: 1853 return 1854 if len(match) > 1: 1855 raise ValueError( 1856 'Please provide different names for the metrics you have added. ' 1857 'We found {} metrics with the name: "{}"'.format(len(match), name)) 1858 return match[0] 1859 1860 def _symbolic_add_metric(self, value, aggregation=None, name=None): 1861 base_layer_utils.check_graph_consistency(value, method='add_metric') 1862 match = self._get_existing_metric(name) 1863 if aggregation is None: 1864 # Iterate over the metrics and check if the given metric exists already. 1865 # This can happen when a metric instance is created in subclassed model 1866 # layer `__init__` and we have tracked that instance already in 1867 # model.__setattr__. 1868 if match: 1869 result_tensor = value 1870 metric_obj = match 1871 elif hasattr(value, '_metric_obj'): 1872 # We track the instance using the metadata on the result tensor. 1873 result_tensor = value 1874 metric_obj = result_tensor._metric_obj 1875 self._metrics.append(metric_obj) 1876 else: 1877 raise ValueError( 1878 'We do not support adding an aggregated metric result tensor that ' 1879 'is not the output of a `tf.keras.metrics.Metric` metric instance. ' 1880 'Without having access to the metric instance we cannot reset the ' 1881 'state of a metric after every epoch during training. You can ' 1882 'create a `tf.keras.metrics.Metric` instance and pass the result ' 1883 'here or pass an un-aggregated result with `aggregation` parameter ' 1884 'set as `mean`. 
For example: `self.add_metric(tf.reduce_sum(inputs)' 1885 ', name=\'mean_activation\', aggregation=\'mean\')`') 1886 else: 1887 # If a non-aggregated tensor is given as input (ie. `aggregation` is 1888 # explicitly set to `mean`), we wrap the tensor in `Mean` metric. 1889 if match: 1890 result_tensor = match(value) 1891 metric_obj = match 1892 else: 1893 metric_obj, result_tensor = base_layer_utils.create_mean_metric( 1894 value, name) 1895 self._metrics.append(metric_obj) 1896 1897 def _handle_weight_regularization(self, name, variable, regularizer): 1898 """Create lambdas which compute regularization losses.""" 1899 1900 def _loss_for_variable(v): 1901 """Creates a regularization loss `Tensor` for variable `v`.""" 1902 with backend.name_scope(name + '/Regularizer'): 1903 regularization = regularizer(v) 1904 return regularization 1905 1906 if base_layer_utils.is_split_variable(variable): 1907 for v in variable: 1908 self.add_loss(functools.partial(_loss_for_variable, v)) 1909 else: 1910 self.add_loss(functools.partial(_loss_for_variable, variable)) 1911 1912 def _handle_activity_regularization(self, inputs, outputs): 1913 # Apply activity regularization. 1914 # Note that it should be applied every time the layer creates a new 1915 # output, since it is output-specific. 1916 if self._activity_regularizer: 1917 output_list = nest.flatten(outputs) 1918 with backend.name_scope('ActivityRegularizer'): 1919 for output in output_list: 1920 activity_loss = self._activity_regularizer(output) 1921 batch_size = math_ops.cast( 1922 array_ops.shape(output)[0], activity_loss.dtype) 1923 # Make activity regularization strength batch-agnostic. 1924 mean_activity_loss = activity_loss / batch_size 1925 base_layer_utils.check_graph_consistency( 1926 mean_activity_loss, method='activity_regularizer') 1927 self.add_loss(mean_activity_loss, inputs=inputs) 1928 1929 def _set_mask_metadata(self, inputs, outputs, previous_mask): 1930 flat_outputs = nest.flatten(outputs) 1931 1932 mask_already_computed = ( 1933 getattr(self, '_compute_output_and_mask_jointly', False) or 1934 all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) 1935 1936 # Only compute the mask if the Layer explicitly supports masking or has 1937 # overridden `compute_mask`. 1938 should_compute_mask = ( 1939 hasattr(self, 'compute_mask') and 1940 (self.supports_masking or 1941 not getattr(self.compute_mask, '_is_default', False))) 1942 1943 if mask_already_computed: 1944 flat_masks = [getattr(x, '_keras_mask', None) for x in flat_outputs] 1945 elif not should_compute_mask: 1946 flat_masks = [None for _ in flat_outputs] 1947 else: 1948 output_masks = self.compute_mask(inputs, previous_mask) 1949 # `compute_mask` can return a single `None` even when a Layer 1950 # has multiple outputs. 1951 if output_masks is None: 1952 flat_masks = [None for _ in flat_outputs] 1953 else: 1954 flat_masks = nest.flatten(output_masks) 1955 1956 for output, mask in zip(flat_outputs, flat_masks): 1957 try: 1958 output._keras_mask = mask 1959 except AttributeError: 1960 # C Type such as np.ndarray. 1961 pass 1962 1963 if tf_utils.are_all_symbolic_tensors(flat_outputs): 1964 for output in flat_outputs: 1965 if getattr(output, '_keras_mask', None) is not None: 1966 # Do not track masks for `TensorFlowOpLayer` construction. 
1967 output._keras_mask._keras_history_checked = True 1968 1969 def _collect_input_masks(self, inputs, args, kwargs): 1970 """Checks if `mask` argument was passed, else gathers mask from inputs.""" 1971 if self._call_arg_was_passed('mask', args, kwargs): 1972 return self._get_call_arg_value('mask', args, kwargs) 1973 1974 if not self._should_compute_mask: 1975 return None 1976 1977 input_masks = nest.map_structure(lambda t: getattr(t, '_keras_mask', None), 1978 inputs) 1979 if generic_utils.is_all_none(input_masks): 1980 return None 1981 return input_masks 1982 1983 def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): 1984 if arg_name in kwargs: 1985 return True 1986 call_fn_args = self._call_fn_args 1987 if not inputs_in_args: 1988 # Ignore `inputs` arg. 1989 call_fn_args = call_fn_args[1:] 1990 if arg_name in dict(zip(call_fn_args, args)): 1991 return True 1992 return False 1993 1994 def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False): 1995 if arg_name in kwargs: 1996 return kwargs[arg_name] 1997 call_fn_args = self._call_fn_args 1998 if not inputs_in_args: 1999 # Ignore `inputs` arg. 2000 call_fn_args = call_fn_args[1:] 2001 args_dict = dict(zip(call_fn_args, args)) 2002 return args_dict[arg_name] 2003 2004 def _set_call_arg_value( 2005 self, arg_name, new_value, args, 2006 kwargs, inputs_in_args=False, pop_kwarg_if_none=False): 2007 arg_pos = self._call_fn_arg_positions.get(arg_name, None) 2008 if arg_pos is not None: 2009 if not inputs_in_args: 2010 # Ignore `inputs` arg. 2011 arg_pos = arg_pos - 1 2012 if len(args) > arg_pos: 2013 args = list(args) 2014 args[arg_pos] = new_value 2015 return args, kwargs 2016 if new_value is None and pop_kwarg_if_none: 2017 kwargs.pop(arg_name, None) 2018 else: 2019 kwargs[arg_name] = new_value 2020 return args, kwargs 2021 2022 def _get_node_attribute_at_index(self, node_index, attr, attr_name): 2023 """Private utility to retrieves an attribute (e.g. inputs) from a node. 2024 2025 This is used to implement the methods: 2026 - get_input_shape_at 2027 - get_output_shape_at 2028 - get_input_at 2029 etc... 2030 2031 Args: 2032 node_index: Integer index of the node from which 2033 to retrieve the attribute. 2034 attr: Exact node attribute name. 2035 attr_name: Human-readable attribute name, for error messages. 2036 2037 Returns: 2038 The layer's attribute `attr` at the node of index `node_index`. 2039 2040 Raises: 2041 RuntimeError: If the layer has no inbound nodes, or if called in Eager 2042 mode. 2043 ValueError: If the index provided does not match any node. 2044 """ 2045 if not self._inbound_nodes: 2046 raise RuntimeError('The layer has never been called ' 2047 'and thus has no defined ' + attr_name + '.') 2048 if not len(self._inbound_nodes) > node_index: 2049 raise ValueError('Asked to get ' + attr_name + ' at node ' + 2050 str(node_index) + ', but the layer has only ' + 2051 str(len(self._inbound_nodes)) + ' inbound nodes.') 2052 values = getattr(self._inbound_nodes[node_index], attr) 2053 if isinstance(values, list) and len(values) == 1: 2054 return values[0] 2055 else: 2056 return values 2057 2058 def _maybe_build(self, inputs): 2059 # Check input assumptions set before layer building, e.g. input rank. 
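    # An illustrative sketch of the user-facing effect of this method (not part
    # of the implementation): weights are created lazily on the first call.
    #   layer = tf.keras.layers.Dense(4)
    #   layer.built                   # False: `build()` has not run yet.
    #   _ = layer(tf.zeros([2, 3]))   # `__call__` triggers `_maybe_build`.
    #   layer.built                   # True: kernel and bias now exist.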
    if not self.built:
      input_spec.assert_input_compatibility(
          self.input_spec, inputs, self.name)
      input_list = nest.flatten(inputs)
      if input_list and self._dtype_policy.compute_dtype is None:
        try:
          dtype = input_list[0].dtype.base_dtype.name
        except AttributeError:
          pass
        else:
          self._set_dtype_policy(policy.Policy(dtype))
      input_shapes = None
      if all(hasattr(x, 'shape') for x in input_list):
        input_shapes = nest.map_structure(lambda x: x.shape, inputs)
      # Only call `build` if the user has manually overridden the build method.
      if not hasattr(self.build, '_is_default'):
        # Any setup work performed only once should happen in an `init_scope`
        # to avoid creating symbolic Tensors that will later pollute any eager
        # operations.
        with tf_utils.maybe_init_scope(self):
          self.build(input_shapes)
      # We must also ensure that the layer is marked as built and that the
      # build shape is stored, since user-defined build functions may not call
      # `super().build()`.
      Layer.build(self, input_shapes)

    # Optionally load weight values specified at layer instantiation.
    if self._initial_weights is not None:
      self.set_weights(self._initial_weights)
      self._initial_weights = None

  def _symbolic_call(self, inputs):
    input_shapes = nest.map_structure(lambda x: x.shape, inputs)
    output_shapes = self.compute_output_shape(input_shapes)

    def _make_placeholder_like(shape):
      ph = backend.placeholder(shape=shape, dtype=self.dtype)
      ph._keras_mask = None
      return ph

    return nest.map_structure(_make_placeholder_like, output_shapes)

  def _get_trainable_state(self):
    """Get the `trainable` state of each sublayer.

    Returns:
      A dict mapping all sublayers to their `trainable` value.
    """
    layers = self._flatten_layers(include_self=False, recursive=False)
    trainable_state = {self: self.trainable}
    for l in layers:
      trainable_state.update(l._get_trainable_state())
    return trainable_state

  def _set_trainable_state(self, trainable_state):
    """Set `trainable` state for each sublayer."""
    if self in trainable_state:
      self.trainable = trainable_state[self]
    layers = self._flatten_layers(include_self=False, recursive=False)
    for l in layers:
      if l in trainable_state:
        l._set_trainable_state(trainable_state)

  @property
  def _obj_reference_counts(self):
    """A dictionary counting the number of attributes referencing an object."""
    self._maybe_create_attribute('_obj_reference_counts_dict',
                                 object_identity.ObjectIdentityDictionary())
    return self._obj_reference_counts_dict

  @trackable.no_automatic_dependency_tracking
  def _maybe_create_attribute(self, name, default_value):
    """Create the attribute with the default value if it hasn't been created.

    This is useful for fields that are used for tracking purposes, such as
    `_trainable_weights` or `_layers`. Note that a user could create a layer
    subclass and assign an internal field before invoking `Layer.__init__()`,
    so `__setattr__()` needs to create these tracking fields and `__init__()`
    must not override them.

    Args:
      name: String, the name of the attribute.
      default_value: Object, the default value of the attribute.
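
    For example (an illustrative sketch; `new_variable` is a hypothetical
    variable being tracked):

      self._maybe_create_attribute('_trainable_weights', [])
      self._trainable_weights.append(new_variable)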
2143 """ 2144 if not hasattr(self, name): 2145 self.__setattr__(name, default_value) 2146 2147 def __delattr__(self, name): 2148 # For any super.__delattr__() call, we will directly use the implementation 2149 # in Trackable and skip the behavior in AutoTrackable. The Layer was 2150 # originally use Trackable as base class, the change of using Module as base 2151 # class forced us to have AutoTrackable in the class hierarchy. Skipping 2152 # the __delattr__ and __setattr__ in AutoTrackable will keep the status quo. 2153 existing_value = getattr(self, name, None) 2154 2155 # If this value is replacing an existing object assigned to an attribute, we 2156 # should clean it out to avoid leaking memory. First we check if there are 2157 # other attributes referencing it. 2158 reference_counts = self._obj_reference_counts 2159 if existing_value not in reference_counts: 2160 super(tracking.AutoTrackable, self).__delattr__(name) 2161 return 2162 2163 reference_count = reference_counts[existing_value] 2164 if reference_count > 1: 2165 # There are other remaining references. We can't remove this object from 2166 # _layers etc. 2167 reference_counts[existing_value] = reference_count - 1 2168 super(tracking.AutoTrackable, self).__delattr__(name) 2169 return 2170 else: 2171 # This is the last remaining reference. 2172 del reference_counts[existing_value] 2173 2174 super(tracking.AutoTrackable, self).__delattr__(name) 2175 2176 if (isinstance(existing_value, Layer) 2177 or base_layer_utils.has_weights(existing_value)): 2178 super(tracking.AutoTrackable, self).__setattr__( 2179 '_self_tracked_trackables', 2180 [l for l in self._self_tracked_trackables if l is not existing_value]) 2181 if isinstance(existing_value, tf_variables.Variable): 2182 super(tracking.AutoTrackable, self).__setattr__( 2183 '_trainable_weights', 2184 [w for w in self._trainable_weights if w is not existing_value]) 2185 super(tracking.AutoTrackable, self).__setattr__( 2186 '_non_trainable_weights', 2187 [w for w in self._non_trainable_weights if w is not existing_value]) 2188 2189 def __setattr__(self, name, value): 2190 if (name == '_self_setattr_tracking' or 2191 not getattr(self, '_self_setattr_tracking', True) or 2192 # Exclude @property.setters from tracking 2193 hasattr(self.__class__, name)): 2194 try: 2195 super(tracking.AutoTrackable, self).__setattr__(name, value) 2196 except AttributeError: 2197 raise AttributeError( 2198 ('Can\'t set the attribute "{}", likely because it conflicts with ' 2199 'an existing read-only @property of the object. Please choose a ' 2200 'different name.').format(name)) 2201 return 2202 2203 # Keep track of trackable objects, for the needs of `Network.save_weights`. 2204 value = data_structures.sticky_attribute_assignment( 2205 trackable=self, value=value, name=name) 2206 2207 reference_counts = self._obj_reference_counts 2208 reference_counts[value] = reference_counts.get(value, 0) + 1 2209 2210 # Clean out the old attribute, which clears _layers and _trainable_weights 2211 # if necessary. 2212 try: 2213 self.__delattr__(name) 2214 except AttributeError: 2215 pass 2216 2217 # Keep track of metric instance created in subclassed layer. 2218 from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top 2219 for val in nest.flatten(value): 2220 if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): 2221 self._metrics.append(val) 2222 2223 # TODO(scottzhu): Need to track Module object as well for weight tracking. 
2224 # Be careful about metric if it becomes a Module in future. 2225 # Append value to self._layers if relevant 2226 if (getattr(self, '_auto_track_sub_layers', True) and 2227 (isinstance(value, Layer) or base_layer_utils.has_weights(value))): 2228 self._maybe_create_attribute('_self_tracked_trackables', []) 2229 # We need to check object identity to avoid de-duplicating empty 2230 # container types which compare equal. 2231 if not any((layer is value for layer in self._self_tracked_trackables)): 2232 self._self_tracked_trackables.append(value) 2233 if hasattr(value, '_use_resource_variables'): 2234 # Legacy layers (V1 tf.layers) must always use 2235 # resource variables. 2236 value._use_resource_variables = True 2237 2238 # Append value to list of trainable / non-trainable weights if relevant 2239 # TODO(b/125122625): This won't pick up on any variables added to a 2240 # list/dict after creation. 2241 for val in nest.flatten(value): 2242 if not isinstance(val, tf_variables.Variable): 2243 continue 2244 2245 # Users may add extra weights/variables 2246 # simply by assigning them to attributes (invalid for graph networks) 2247 self._maybe_create_attribute('_trainable_weights', []) 2248 self._maybe_create_attribute('_non_trainable_weights', []) 2249 if val.trainable: 2250 if any(val is w for w in self._trainable_weights): 2251 continue 2252 self._trainable_weights.append(val) 2253 else: 2254 if any(val is w for w in self._non_trainable_weights): 2255 continue 2256 self._non_trainable_weights.append(val) 2257 2258 backend.track_variable(val) 2259 2260 # Skip the auto trackable from tf.Module to keep status quo. See the comment 2261 # at __delattr__. 2262 super(tracking.AutoTrackable, self).__setattr__(name, value) 2263 2264 # This is a hack so that the is_layer (within 2265 # training/trackable/layer_utils.py) check doesn't get the weights attr. 2266 # TODO(b/110718070): Remove when fixed. 2267 def _is_layer(self): 2268 return True 2269 2270 def _init_call_fn_args(self): 2271 # Clear cached call function arguments. 2272 self.__class__._call_full_argspec.fget.cache.pop(self, None) 2273 self.__class__._call_fn_args.fget.cache.pop(self, None) 2274 self.__class__._call_accepts_kwargs.fget.cache.pop(self, None) 2275 2276 call_fn_args = self._call_fn_args 2277 self._expects_training_arg = ('training' in call_fn_args or 2278 self._call_accepts_kwargs) 2279 self._expects_mask_arg = ('mask' in call_fn_args or 2280 self._call_accepts_kwargs) 2281 2282 @property 2283 @layer_utils.cached_per_instance 2284 def _call_full_argspec(self): 2285 # Argspec inspection is expensive and the call spec is used often, so it 2286 # makes sense to cache the result. 2287 return tf_inspect.getfullargspec(self.call) 2288 2289 @property 2290 @layer_utils.cached_per_instance 2291 def _call_fn_args(self): 2292 all_args = self._call_full_argspec.args 2293 # Scrub `self` that appears if a decorator was applied. 
2294 if all_args and all_args[0] == 'self': 2295 return all_args[1:] 2296 return all_args 2297 2298 @property 2299 @layer_utils.cached_per_instance 2300 def _call_fn_arg_positions(self): 2301 call_fn_arg_positions = dict() 2302 for pos, arg in enumerate(self._call_fn_args): 2303 call_fn_arg_positions[arg] = pos 2304 return call_fn_arg_positions 2305 2306 @property 2307 @layer_utils.cached_per_instance 2308 def _call_accepts_kwargs(self): 2309 return self._call_full_argspec.varkw is not None 2310 2311 @property 2312 @layer_utils.cached_per_instance 2313 def _should_compute_mask(self): 2314 return ('mask' in self._call_fn_args or 2315 getattr(self, 'compute_mask', None) is not None) 2316 2317 def _dedup_weights(self, weights): 2318 """Dedupe weights while maintaining order as much as possible.""" 2319 output, seen_ids = [], set() 2320 for w in weights: 2321 if id(w) not in seen_ids: 2322 output.append(w) 2323 # Track the Variable's identity to avoid __eq__ issues. 2324 seen_ids.add(id(w)) 2325 2326 return output 2327 2328 # SavedModel properties. Please see keras/saving/saved_model for details. 2329 2330 @property 2331 def _trackable_saved_model_saver(self): 2332 return layer_serialization.LayerSavedModelSaver(self) 2333 2334 @property 2335 def _object_identifier(self): 2336 return self._trackable_saved_model_saver.object_identifier 2337 2338 @property 2339 def _tracking_metadata(self): 2340 return self._trackable_saved_model_saver.tracking_metadata 2341 2342 def _list_extra_dependencies_for_serialization(self, serialization_cache): 2343 return (self._trackable_saved_model_saver 2344 .list_extra_dependencies_for_serialization(serialization_cache)) 2345 2346 def _list_functions_for_serialization(self, serialization_cache): 2347 return (self._trackable_saved_model_saver 2348 .list_functions_for_serialization(serialization_cache)) 2349 2350 def __getstate__(self): 2351 # Override to support `copy.deepcopy` and pickling. 2352 # Thread-local objects cannot be copied in Python 3, so pop these. 2353 # Thread-local objects are used to cache losses in MirroredStrategy, and 2354 # so shouldn't be copied. 2355 state = self.__dict__.copy() 2356 state.pop('_thread_local', None) 2357 return state 2358 2359 def __setstate__(self, state): 2360 state['_thread_local'] = threading.local() 2361 # Bypass Trackable logic as `__dict__` already contains this info. 2362 object.__setattr__(self, '__dict__', state) 2363 2364 2365class KerasHistory( 2366 collections.namedtuple('KerasHistory', 2367 ['layer', 'node_index', 'tensor_index'])): 2368 """Tracks the Layer call that created a Tensor, for Keras Graph Networks. 2369 2370 During construction of Keras Graph Networks, this metadata is added to 2371 each Tensor produced as the output of a Layer, starting with an 2372 `InputLayer`. This allows Keras to track how each Tensor was produced, and 2373 this information is later retraced by the `keras.engine.Network` class to 2374 reconstruct the Keras Graph Network. 2375 2376 Attributes: 2377 layer: The Layer that produced the Tensor. 2378 node_index: The specific call to the Layer that produced this Tensor. Layers 2379 can be called multiple times in order to share weights. A new node is 2380 created every time a Tensor is called. 2381 tensor_index: The output index for this Tensor. Always zero if the Layer 2382 that produced this Tensor only has one output. Nested structures of 2383 Tensors are deterministically assigned an index via `nest.flatten`. 
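
  For example (an illustrative sketch using the Functional API):

    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(2)(inputs)
    # `outputs._keras_history` is now a `KerasHistory` pointing at the `Dense`
    # layer, with `node_index=0` and `tensor_index=0`.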
2384 """ 2385 # Added to maintain memory and performance characteristics of `namedtuple` 2386 # while subclassing. 2387 __slots__ = () 2388 2389 2390# Avoid breaking users who directly import this symbol from this file. 2391# TODO(fchollet): remove this. 2392InputSpec = input_spec.InputSpec # pylint:disable=invalid-name 2393
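

# A minimal sketch of how the `InputSpec` re-exported above is typically used
# inside a custom layer's `build()` (illustrative only; `MyDense` is a
# hypothetical subclass, not part of this module):
#
#   class MyDense(Layer):
#
#     def build(self, input_shape):
#       last_dim = int(input_shape[-1])
#       self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
#       self.kernel = self.add_weight('kernel', shape=(last_dim, 8))
#       super(MyDense, self).build(input_shape)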