1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Contains the base Layer class, from which all layers inherit.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import functools 24import itertools 25import threading 26import warnings 27import weakref 28 29import numpy as np 30import six 31from six.moves import zip # pylint: disable=redefined-builtin 32 33from google.protobuf import json_format 34from tensorflow.core.framework import node_def_pb2 35from tensorflow.python import tf2 36from tensorflow.python.autograph.core import ag_ctx 37from tensorflow.python.autograph.impl import api as autograph 38from tensorflow.python.distribute import distribution_strategy_context as ds_context 39from tensorflow.python.eager import context 40from tensorflow.python.eager import def_function 41from tensorflow.python.eager import execute 42from tensorflow.python.eager import monitoring 43from tensorflow.python.framework import constant_op 44from tensorflow.python.framework import dtypes 45from tensorflow.python.framework import errors 46from tensorflow.python.framework import func_graph 47from tensorflow.python.framework import ops 48from tensorflow.python.framework import sparse_tensor 49from tensorflow.python.framework import tensor_spec 50from tensorflow.python.framework import tensor_util 51from tensorflow.python.keras import backend 52from tensorflow.python.keras import constraints 53from tensorflow.python.keras import initializers 54from tensorflow.python.keras import regularizers 55from tensorflow.python.keras.engine import base_layer_utils 56from tensorflow.python.keras.engine import input_spec 57from tensorflow.python.keras.engine import keras_tensor 58from tensorflow.python.keras.engine import node as node_module 59from tensorflow.python.keras.mixed_precision import autocast_variable 60from tensorflow.python.keras.mixed_precision import loss_scale_optimizer 61from tensorflow.python.keras.mixed_precision import policy 62from tensorflow.python.keras.saving.saved_model import layer_serialization 63from tensorflow.python.keras.utils import generic_utils 64from tensorflow.python.keras.utils import layer_utils 65from tensorflow.python.keras.utils import object_identity 66from tensorflow.python.keras.utils import tf_inspect 67from tensorflow.python.keras.utils import tf_utils 68from tensorflow.python.keras.utils import version_utils 69# A module that only depends on `keras.layers` import these from here. 70from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import 71from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import 72 73from tensorflow.python.module import module 74from tensorflow.python.ops import array_ops 75from tensorflow.python.ops import math_ops 76from tensorflow.python.ops import variables as tf_variables 77from tensorflow.python.ops.numpy_ops import np_arrays 78from tensorflow.python.ops.ragged import ragged_tensor 79from tensorflow.python.platform import tf_logging 80from tensorflow.python.training.tracking import base as trackable 81from tensorflow.python.training.tracking import data_structures 82from tensorflow.python.training.tracking import tracking 83from tensorflow.python.util import compat 84from tensorflow.python.util import nest 85from tensorflow.python.util.tf_export import get_canonical_name_for_symbol 86from tensorflow.python.util.tf_export import keras_export 87from tensorflow.tools.docs import doc_controls 88 89# pylint: disable=g-inconsistent-quotes 90metrics_mod = generic_utils.LazyLoader( 91 "metrics_mod", globals(), 92 "tensorflow.python.keras.metrics") 93# pylint: enable=g-inconsistent-quotes 94 95# Prefix that is added to the TF op layer names. 96_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_' 97 98# TODO(mdan): Should we have a single generic type for types that can be passed 99# to tf.cast? 100_AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor, 101 ragged_tensor.RaggedTensor) 102 103keras_layers_gauge = monitoring.BoolGauge('/tensorflow/api/keras/layers', 104 'keras layers usage', 'method') 105keras_models_gauge = monitoring.BoolGauge( 106 '/tensorflow/api/keras/models', 'keras model usage', 'method') 107keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras', 108 'keras api usage', 'method') 109keras_premade_model_gauge = monitoring.BoolGauge( 110 '/tensorflow/api/keras/premade_models', 'premade keras model usage', 'type') 111 112 113@keras_export('keras.layers.Layer') 114class Layer(module.Module, version_utils.LayerVersionSelector): 115 """This is the class from which all layers inherit. 116 117 A layer is a callable object that takes as input one or more tensors and 118 that outputs one or more tensors. It involves *computation*, defined 119 in the `call()` method, and a *state* (weight variables), defined 120 either in the constructor `__init__()` or in the `build()` method. 121 122 Users will just instantiate a layer and then treat it as a callable. 123 124 Args: 125 trainable: Boolean, whether the layer's variables should be trainable. 126 name: String name of the layer. 127 dtype: The dtype of the layer's computations and weights. Can also be a 128 `tf.keras.mixed_precision.Policy`, which allows the computation and weight 129 dtype to differ. Default of `None` means to use 130 `tf.keras.mixed_precision.global_policy()`, which is a float32 policy 131 unless set to different value. 132 dynamic: Set this to `True` if your layer should only be run eagerly, and 133 should not be used to generate a static computation graph. 134 This would be the case for a Tree-RNN or a recursive network, 135 for example, or generally for any layer that manipulates tensors 136 using Python control flow. If `False`, we assume that the layer can 137 safely be used to generate a static computation graph. 138 139 Attributes: 140 name: The name of the layer (string). 141 dtype: The dtype of the layer's weights. 142 variable_dtype: Alias of `dtype`. 143 compute_dtype: The dtype of the layer's computations. Layers automatically 144 cast inputs to this dtype which causes the computations and output to also 145 be in this dtype. When mixed precision is used with a 146 `tf.keras.mixed_precision.Policy`, this will be different than 147 `variable_dtype`. 148 dtype_policy: The layer's dtype policy. See the 149 `tf.keras.mixed_precision.Policy` documentation for details. 150 trainable_weights: List of variables to be included in backprop. 151 non_trainable_weights: List of variables that should not be 152 included in backprop. 153 weights: The concatenation of the lists trainable_weights and 154 non_trainable_weights (in this order). 155 trainable: Whether the layer should be trained (boolean), i.e. whether 156 its potentially-trainable weights should be returned as part of 157 `layer.trainable_weights`. 158 input_spec: Optional (list of) `InputSpec` object(s) specifying the 159 constraints on inputs that can be accepted by the layer. 160 161 We recommend that descendants of `Layer` implement the following methods: 162 163 * `__init__()`: Defines custom layer attributes, and creates layer state 164 variables that do not depend on input shapes, using `add_weight()`. 165 * `build(self, input_shape)`: This method can be used to create weights that 166 depend on the shape(s) of the input(s), using `add_weight()`. `__call__()` 167 will automatically build the layer (if it has not been built yet) by 168 calling `build()`. 169 * `call(self, inputs, *args, **kwargs)`: Called in `__call__` after making 170 sure `build()` has been called. `call()` performs the logic of applying the 171 layer to the input tensors (which should be passed in as argument). 172 Two reserved keyword arguments you can optionally use in `call()` are: 173 - `training` (boolean, whether the call is in inference mode or training 174 mode). See more details in [the layer/model subclassing guide]( 175 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_training_argument_in_the_call_method) 176 - `mask` (boolean tensor encoding masked timesteps in the input, used 177 in RNN layers). See more details in [the layer/model subclassing guide]( 178 https://www.tensorflow.org/guide/keras/custom_layers_and_models#privileged_mask_argument_in_the_call_method) 179 A typical signature for this method is `call(self, inputs)`, and user could 180 optionally add `training` and `mask` if the layer need them. `*args` and 181 `**kwargs` is only useful for future extension when more input parameters 182 are planned to be added. 183 * `get_config(self)`: Returns a dictionary containing the configuration used 184 to initialize this layer. If the keys differ from the arguments 185 in `__init__`, then override `from_config(self)` as well. 186 This method is used when saving 187 the layer or a model that contains this layer. 188 189 Examples: 190 191 Here's a basic example: a layer with two variables, `w` and `b`, 192 that returns `y = w . x + b`. 193 It shows how to implement `build()` and `call()`. 194 Variables set as attributes of a layer are tracked as weights 195 of the layers (in `layer.weights`). 196 197 ```python 198 class SimpleDense(Layer): 199 200 def __init__(self, units=32): 201 super(SimpleDense, self).__init__() 202 self.units = units 203 204 def build(self, input_shape): # Create the state of the layer (weights) 205 w_init = tf.random_normal_initializer() 206 self.w = tf.Variable( 207 initial_value=w_init(shape=(input_shape[-1], self.units), 208 dtype='float32'), 209 trainable=True) 210 b_init = tf.zeros_initializer() 211 self.b = tf.Variable( 212 initial_value=b_init(shape=(self.units,), dtype='float32'), 213 trainable=True) 214 215 def call(self, inputs): # Defines the computation from inputs to outputs 216 return tf.matmul(inputs, self.w) + self.b 217 218 # Instantiates the layer. 219 linear_layer = SimpleDense(4) 220 221 # This will also call `build(input_shape)` and create the weights. 222 y = linear_layer(tf.ones((2, 2))) 223 assert len(linear_layer.weights) == 2 224 225 # These weights are trainable, so they're listed in `trainable_weights`: 226 assert len(linear_layer.trainable_weights) == 2 227 ``` 228 229 Note that the method `add_weight()` offers a shortcut to create weights: 230 231 ```python 232 class SimpleDense(Layer): 233 234 def __init__(self, units=32): 235 super(SimpleDense, self).__init__() 236 self.units = units 237 238 def build(self, input_shape): 239 self.w = self.add_weight(shape=(input_shape[-1], self.units), 240 initializer='random_normal', 241 trainable=True) 242 self.b = self.add_weight(shape=(self.units,), 243 initializer='random_normal', 244 trainable=True) 245 246 def call(self, inputs): 247 return tf.matmul(inputs, self.w) + self.b 248 ``` 249 250 Besides trainable weights, updated via backpropagation during training, 251 layers can also have non-trainable weights. These weights are meant to 252 be updated manually during `call()`. Here's a example layer that computes 253 the running sum of its inputs: 254 255 ```python 256 class ComputeSum(Layer): 257 258 def __init__(self, input_dim): 259 super(ComputeSum, self).__init__() 260 # Create a non-trainable weight. 261 self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), 262 trainable=False) 263 264 def call(self, inputs): 265 self.total.assign_add(tf.reduce_sum(inputs, axis=0)) 266 return self.total 267 268 my_sum = ComputeSum(2) 269 x = tf.ones((2, 2)) 270 271 y = my_sum(x) 272 print(y.numpy()) # [2. 2.] 273 274 y = my_sum(x) 275 print(y.numpy()) # [4. 4.] 276 277 assert my_sum.weights == [my_sum.total] 278 assert my_sum.non_trainable_weights == [my_sum.total] 279 assert my_sum.trainable_weights == [] 280 ``` 281 282 For more information about creating layers, see the guide 283 [Making new Layers and Models via subclassing]( 284 https://www.tensorflow.org/guide/keras/custom_layers_and_models) 285 """ 286 287 # See tf.Module for the usage of this property. 288 # The key for _obj_reference_counts_dict is a Trackable, which could be a 289 # variable or layer etc. tf.Module._flatten will fail to flatten the key 290 # since it is trying to convert Trackable to a string. This attribute can be 291 # ignored even after the fix of nest lib, since the trackable object should 292 # already been available as individual attributes. _obj_reference_counts_dict 293 # just contains a copy of them. 294 _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain( 295 ('_obj_reference_counts_dict',), 296 module.Module._TF_MODULE_IGNORED_PROPERTIES 297 )) 298 299 # When loading from a SavedModel, Layers typically can be revived into a 300 # generic Layer wrapper. Sometimes, however, layers may implement methods 301 # that go beyond this wrapper, as in the case of PreprocessingLayers' 302 # `adapt` method. When this is the case, layer implementers can override 303 # must_restore_from_config to return True; layers with this property must 304 # be restored into their actual objects (and will fail if the object is 305 # not available to the restoration code). 306 _must_restore_from_config = False 307 308 def _get_cell_name(self): 309 canonical_name = get_canonical_name_for_symbol( 310 self.__class__, api_name='keras', add_prefix_to_v1_names=True) 311 if canonical_name is not None: 312 return 'tf.{}'.format(canonical_name) 313 return self.__class__.__module__ + '.' + self.__class__.__name__ 314 315 def _instrument_layer_creation(self): 316 self._instrumented_keras_api = False 317 self._instrumented_keras_layer_class = False 318 self._instrumented_keras_model_class = False 319 if not getattr(self, '_disable_keras_instrumentation', False): 320 keras_api_gauge.get_cell('layer').set(True) 321 self._instrumented_keras_api = True 322 if getattr(self, '_is_model_for_instrumentation', False): 323 keras_models_gauge.get_cell(self._get_cell_name()).set(True) 324 self._instrumented_keras_model_class = True 325 else: 326 keras_layers_gauge.get_cell(self._get_cell_name()).set(True) 327 self._instrumented_keras_layer_class = True 328 329 @trackable.no_automatic_dependency_tracking 330 def __init__(self, 331 trainable=True, 332 name=None, 333 dtype=None, 334 dynamic=False, 335 **kwargs): 336 self._instrument_layer_creation() 337 338 # These properties should be set by the user via keyword arguments. 339 # note that 'dtype', 'input_shape' and 'batch_input_shape' 340 # are only applicable to input layers: do not pass these keywords 341 # to non-input layers. 342 allowed_kwargs = { 343 'input_dim', 344 'input_shape', 345 'batch_input_shape', 346 'batch_size', 347 'weights', 348 'activity_regularizer', 349 'autocast', 350 'implementation', 351 } 352 # Validate optional keyword arguments. 353 generic_utils.validate_kwargs(kwargs, allowed_kwargs) 354 355 # Mutable properties 356 # Indicates whether the layer's weights are updated during training 357 # and whether the layer's updates are run during training. 358 self._trainable = trainable 359 # A stateful layer is a layer whose updates are run during inference too, 360 # for instance stateful RNNs. 361 self._stateful = False 362 # Indicates whether `build` needs to be called upon layer call, to create 363 # the layer's weights. 364 self.built = False 365 # Provides information about which inputs are compatible with the layer. 366 self._input_spec = None 367 368 # SavedModel-related attributes. 369 # Record the build input shape for loading purposes. 370 # TODO(kathywu): Move this to Layer._set_save_spec once cl/290121460 is 371 # submitted. 372 self._build_input_shape = None 373 self._saved_model_inputs_spec = None 374 375 # `Layer.compute_mask` will be called at the end of `Layer.__call__` if 376 # `Layer.compute_mask` is overridden, or if the `Layer` subclass sets 377 # `self.supports_masking=True`. 378 self._supports_masking = not generic_utils.is_default(self.compute_mask) 379 380 self._init_set_name(name) 381 self._activity_regularizer = regularizers.get( 382 kwargs.pop('activity_regularizer', None)) 383 self._maybe_create_attribute('_trainable_weights', []) 384 self._maybe_create_attribute('_non_trainable_weights', []) 385 self._updates = [] 386 # Object to store all thread local layer properties. 387 self._thread_local = threading.local() 388 # A list of zero-argument lambdas which return Tensors, used for variable 389 # regularizers. 390 self._callable_losses = [] 391 # A list of symbolic Tensors containing activity regularizers and losses 392 # manually added through `add_loss` in graph-building mode. 393 self._losses = [] 394 # A list of metric instances corresponding to the symbolic metric tensors 395 # added using the `add_metric` API. 396 self._metrics = [] 397 # Ensures the same metric is not added multiple times in `MirroredStrategy`. 398 self._metrics_lock = threading.Lock() 399 400 # Both graph and subclassed networks have a dtype policy. For graph 401 # networks, the policy's compute and variable dtypes are ignored. Such 402 # networks only use the policy if it is a PolicyV1, in which case it uses 403 # the PolicyV1's loss_scale (Policy does not have a loss_scale). For 404 # subclassed networks, the compute and variable dtypes are used as like any 405 # ordinary layer. 406 self._set_dtype_policy(dtype) 407 # Boolean indicating whether the layer automatically casts its inputs to the 408 # layer's compute_dtype. 409 self._autocast = kwargs.get('autocast', 410 base_layer_utils.v2_dtype_behavior_enabled()) 411 412 # Tracks `TrackableDataStructure`s, `Module`s, and `Layer`s. 413 # Ordered by when the object was assigned as an attr. 414 # Entries are unique. 415 self._maybe_create_attribute('_self_tracked_trackables', []) 416 417 # These lists will be filled via successive calls 418 # to self._add_inbound_node(). 419 # Used in symbolic mode only, only in conjunction with graph-networks 420 self._inbound_nodes_value = [] 421 self._outbound_nodes_value = [] 422 423 self._init_call_fn_args() 424 425 # Whether the `call` method can be used to build a TF graph without issues. 426 # This attribute has no effect if the model is created using the Functional 427 # API. Instead, `model.dynamic` is determined based on the internal layers. 428 self._dynamic = dynamic 429 430 # Manage input shape information if passed. 431 if 'input_dim' in kwargs and 'input_shape' not in kwargs: 432 # Backwards compatibility: alias 'input_dim' to 'input_shape'. 433 kwargs['input_shape'] = (kwargs['input_dim'],) 434 if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: 435 # In this case we will later create an input layer 436 # to insert before the current layer 437 if 'batch_input_shape' in kwargs: 438 batch_input_shape = tuple(kwargs['batch_input_shape']) 439 elif 'input_shape' in kwargs: 440 if 'batch_size' in kwargs: 441 batch_size = kwargs['batch_size'] 442 else: 443 batch_size = None 444 batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) 445 self._batch_input_shape = batch_input_shape 446 447 # Manage initial weight values if passed. 448 self._initial_weights = kwargs.get('weights', None) 449 450 # Whether the layer will track any layers that is set as attribute on itself 451 # as sub-layers, the weights from the sub-layers will be included in the 452 # parent layer's variables() as well. 453 # Default to True, which means auto tracking is turned on. Certain subclass 454 # might want to turn it off, like Sequential model. 455 self._auto_track_sub_layers = True 456 457 # For backwards compat reasons, most built-in layers do not guarantee 458 # That they will 100% preserve the structure of input args when saving 459 # / loading configs. E.g. they may un-nest an arg that is 460 # a list with one element. 461 self._preserve_input_structure_in_config = False 462 463 @trackable.no_automatic_dependency_tracking 464 @generic_utils.default 465 def build(self, input_shape): 466 """Creates the variables of the layer (optional, for subclass implementers). 467 468 This is a method that implementers of subclasses of `Layer` or `Model` 469 can override if they need a state-creation step in-between 470 layer instantiation and layer call. 471 472 This is typically used to create the weights of `Layer` subclasses. 473 474 Args: 475 input_shape: Instance of `TensorShape`, or list of instances of 476 `TensorShape` if the layer expects a list of inputs 477 (one instance per input). 478 """ 479 # Only record the build input shapes of overridden build methods. 480 if not hasattr(self.build, '_is_default'): 481 self._build_input_shape = input_shape 482 self.built = True 483 484 @doc_controls.for_subclass_implementers 485 def call(self, inputs, *args, **kwargs): # pylint: disable=unused-argument 486 """This is where the layer's logic lives. 487 488 Note here that `call()` method in `tf.keras` is little bit different 489 from `keras` API. In `keras` API, you can pass support masking for 490 layers as additional arguments. Whereas `tf.keras` has `compute_mask()` 491 method to support masking. 492 493 Args: 494 inputs: Input tensor, or list/tuple of input tensors. 495 *args: Additional positional arguments. Currently unused. 496 **kwargs: Additional keyword arguments. Currently unused. 497 498 Returns: 499 A tensor or list/tuple of tensors. 500 """ 501 return inputs 502 503 @doc_controls.for_subclass_implementers 504 def _add_trackable(self, trackable_object, trainable): 505 """Adds a Trackable object to this layer's state. 506 507 Args: 508 trackable_object: The tf.tracking.Trackable object to add. 509 trainable: Boolean, whether the variable should be part of the layer's 510 "trainable_variables" (e.g. variables, biases) or 511 "non_trainable_variables" (e.g. BatchNorm mean and variance). 512 513 Returns: 514 The TrackableWeightHandler used to track this object. 515 """ 516 handler = base_layer_utils.TrackableWeightHandler(trackable_object) 517 if trainable: 518 self._trainable_weights.append(handler) 519 else: 520 self._non_trainable_weights.append(handler) 521 return handler 522 523 @doc_controls.for_subclass_implementers 524 def add_weight(self, 525 name=None, 526 shape=None, 527 dtype=None, 528 initializer=None, 529 regularizer=None, 530 trainable=None, 531 constraint=None, 532 use_resource=None, 533 synchronization=tf_variables.VariableSynchronization.AUTO, 534 aggregation=tf_variables.VariableAggregation.NONE, 535 **kwargs): 536 """Adds a new variable to the layer. 537 538 Args: 539 name: Variable name. 540 shape: Variable shape. Defaults to scalar if unspecified. 541 dtype: The type of the variable. Defaults to `self.dtype`. 542 initializer: Initializer instance (callable). 543 regularizer: Regularizer instance (callable). 544 trainable: Boolean, whether the variable should be part of the layer's 545 "trainable_variables" (e.g. variables, biases) 546 or "non_trainable_variables" (e.g. BatchNorm mean and variance). 547 Note that `trainable` cannot be `True` if `synchronization` 548 is set to `ON_READ`. 549 constraint: Constraint instance (callable). 550 use_resource: Whether to use `ResourceVariable`. 551 synchronization: Indicates when a distributed a variable will be 552 aggregated. Accepted values are constants defined in the class 553 `tf.VariableSynchronization`. By default the synchronization is set to 554 `AUTO` and the current `DistributionStrategy` chooses 555 when to synchronize. If `synchronization` is set to `ON_READ`, 556 `trainable` must not be set to `True`. 557 aggregation: Indicates how a distributed variable will be aggregated. 558 Accepted values are constants defined in the class 559 `tf.VariableAggregation`. 560 **kwargs: Additional keyword arguments. Accepted values are `getter`, 561 `collections`, `experimental_autocast` and `caching_device`. 562 563 Returns: 564 The variable created. 565 566 Raises: 567 ValueError: When giving unsupported dtype and no initializer or when 568 trainable has been set to True with synchronization set as `ON_READ`. 569 """ 570 if shape is None: 571 shape = () 572 kwargs.pop('partitioner', None) # Ignored. 573 # Validate optional keyword arguments. 574 for kwarg in kwargs: 575 if kwarg not in ['collections', 'experimental_autocast', 576 'caching_device', 'getter']: 577 raise TypeError('Unknown keyword argument:', kwarg) 578 collections_arg = kwargs.pop('collections', None) 579 # 'experimental_autocast' can be set to False by the caller to indicate an 580 # AutoCastVariable should never be created. 581 autocast = kwargs.pop('experimental_autocast', True) 582 # See the docstring for tf.Variable about the details for caching_device. 583 caching_device = kwargs.pop('caching_device', None) 584 585 if dtype is None: 586 dtype = self.dtype or backend.floatx() 587 dtype = dtypes.as_dtype(dtype) 588 if self._dtype_policy.variable_dtype is None: 589 # The policy is "_infer", so we infer the policy from the variable dtype. 590 self._set_dtype_policy(policy.Policy(dtype.base_dtype.name)) 591 initializer = initializers.get(initializer) 592 regularizer = regularizers.get(regularizer) 593 constraint = constraints.get(constraint) 594 595 if synchronization == tf_variables.VariableSynchronization.ON_READ: 596 if trainable: 597 raise ValueError( 598 'Synchronization value can be set to ' 599 'VariableSynchronization.ON_READ only for non-trainable variables. ' 600 'You have specified trainable=True and ' 601 'synchronization=VariableSynchronization.ON_READ.') 602 else: 603 # Set trainable to be false when variable is to be synced on read. 604 trainable = False 605 elif trainable is None: 606 trainable = True 607 608 # Initialize variable when no initializer provided 609 if initializer is None: 610 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer 611 if dtype.is_floating: 612 initializer = initializers.get('glorot_uniform') 613 # If dtype is DT_INT/DT_UINT, provide a default value `zero` 614 # If dtype is DT_BOOL, provide a default value `FALSE` 615 elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: 616 initializer = initializers.get('zeros') 617 # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here? 618 else: 619 raise ValueError('An initializer for variable %s of type %s is required' 620 ' for layer %s' % (name, dtype.base_dtype, self.name)) 621 622 getter = kwargs.pop('getter', base_layer_utils.make_variable) 623 if (autocast and 624 self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype 625 and dtype.is_floating): 626 old_getter = getter 627 # Wrap variable constructor to return an AutoCastVariable. 628 def getter(*args, **kwargs): # pylint: disable=function-redefined 629 variable = old_getter(*args, **kwargs) 630 return autocast_variable.create_autocast_variable(variable) 631 # Also the caching_device does not work with the mixed precision API, 632 # disable it if it is specified. 633 # TODO(b/142020079): Reenable it once the bug is fixed. 634 if caching_device is not None: 635 tf_logging.warn('`caching_device` does not work with mixed precision ' 636 'API. Ignoring user specified `caching_device`.') 637 caching_device = None 638 639 variable = self._add_variable_with_custom_getter( 640 name=name, 641 shape=shape, 642 # TODO(allenl): a `make_variable` equivalent should be added as a 643 # `Trackable` method. 644 getter=getter, 645 # Manage errors in Layer rather than Trackable. 646 overwrite=True, 647 initializer=initializer, 648 dtype=dtype, 649 constraint=constraint, 650 trainable=trainable, 651 use_resource=use_resource, 652 collections=collections_arg, 653 synchronization=synchronization, 654 aggregation=aggregation, 655 caching_device=caching_device) 656 if regularizer is not None: 657 # TODO(fchollet): in the future, this should be handled at the 658 # level of variable creation, and weight regularization losses 659 # should be variable attributes. 660 name_in_scope = variable.name[:variable.name.find(':')] 661 self._handle_weight_regularization(name_in_scope, 662 variable, 663 regularizer) 664 if base_layer_utils.is_split_variable(variable): 665 for v in variable: 666 backend.track_variable(v) 667 if trainable: 668 self._trainable_weights.append(v) 669 else: 670 self._non_trainable_weights.append(v) 671 else: 672 backend.track_variable(variable) 673 if trainable: 674 self._trainable_weights.append(variable) 675 else: 676 self._non_trainable_weights.append(variable) 677 return variable 678 679 @generic_utils.default 680 def get_config(self): 681 """Returns the config of the layer. 682 683 A layer config is a Python dictionary (serializable) 684 containing the configuration of a layer. 685 The same layer can be reinstantiated later 686 (without its trained weights) from this configuration. 687 688 The config of a layer does not include connectivity 689 information, nor the layer class name. These are handled 690 by `Network` (one layer of abstraction above). 691 692 Note that `get_config()` does not guarantee to return a fresh copy of dict 693 every time it is called. The callers should make a copy of the returned dict 694 if they want to modify it. 695 696 Returns: 697 Python dictionary. 698 """ 699 all_args = tf_inspect.getfullargspec(self.__init__).args 700 config = { 701 'name': self.name, 702 'trainable': self.trainable, 703 } 704 if hasattr(self, '_batch_input_shape'): 705 config['batch_input_shape'] = self._batch_input_shape 706 config['dtype'] = policy.serialize(self._dtype_policy) 707 if hasattr(self, 'dynamic'): 708 # Only include `dynamic` in the `config` if it is `True` 709 if self.dynamic: 710 config['dynamic'] = self.dynamic 711 elif 'dynamic' in all_args: 712 all_args.remove('dynamic') 713 expected_args = config.keys() 714 # Finds all arguments in the `__init__` that are not in the config: 715 extra_args = [arg for arg in all_args if arg not in expected_args] 716 # Check that either the only argument in the `__init__` is `self`, 717 # or that `get_config` has been overridden: 718 if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'): 719 raise NotImplementedError('Layer %s has arguments in `__init__` and ' 720 'therefore must override `get_config`.' % 721 self.__class__.__name__) 722 return config 723 724 @classmethod 725 def from_config(cls, config): 726 """Creates a layer from its config. 727 728 This method is the reverse of `get_config`, 729 capable of instantiating the same layer from the config 730 dictionary. It does not handle layer connectivity 731 (handled by Network), nor weights (handled by `set_weights`). 732 733 Args: 734 config: A Python dictionary, typically the 735 output of get_config. 736 737 Returns: 738 A layer instance. 739 """ 740 return cls(**config) 741 742 def compute_output_shape(self, input_shape): 743 """Computes the output shape of the layer. 744 745 If the layer has not been built, this method will call `build` on the 746 layer. This assumes that the layer will later be used with inputs that 747 match the input shape provided here. 748 749 Args: 750 input_shape: Shape tuple (tuple of integers) 751 or list of shape tuples (one per output tensor of the layer). 752 Shape tuples can include None for free dimensions, 753 instead of an integer. 754 755 Returns: 756 An input shape tuple. 757 """ 758 if context.executing_eagerly(): 759 # In this case we build the model first in order to do shape inference. 760 # This is acceptable because the framework only calls 761 # `compute_output_shape` on shape values that the layer would later be 762 # built for. It would however cause issues in case a user attempts to 763 # use `compute_output_shape` manually with shapes that are incompatible 764 # with the shape the Layer will be called on (these users will have to 765 # implement `compute_output_shape` themselves). 766 self._maybe_build(input_shape) 767 with func_graph.FuncGraph(str(self.name) + '_scratch_graph').as_default(): 768 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 769 def _make_placeholder_like(shape): 770 ph = backend.placeholder(shape=shape, dtype=self.dtype) 771 ph._keras_mask = None 772 return ph 773 inputs = nest.map_structure(_make_placeholder_like, input_shape) 774 try: 775 outputs = self(inputs, training=False) 776 except TypeError as e: 777 six.raise_from( 778 NotImplementedError( 779 'We could not automatically infer the static shape of the ' 780 'layer\'s output. Please implement the ' 781 '`compute_output_shape` method on your layer (%s).' % 782 self.__class__.__name__), e) 783 return nest.map_structure(lambda t: t.shape, outputs) 784 raise NotImplementedError( 785 'Please run in eager mode or implement the `compute_output_shape` ' 786 'method on your layer (%s).' % self.__class__.__name__) 787 788 @doc_controls.for_subclass_implementers 789 def compute_output_signature(self, input_signature): 790 """Compute the output tensor signature of the layer based on the inputs. 791 792 Unlike a TensorShape object, a TensorSpec object contains both shape 793 and dtype information for a tensor. This method allows layers to provide 794 output dtype information if it is different from the input dtype. 795 For any layer that doesn't implement this function, 796 the framework will fall back to use `compute_output_shape`, and will 797 assume that the output dtype matches the input dtype. 798 799 Args: 800 input_signature: Single TensorSpec or nested structure of TensorSpec 801 objects, describing a candidate input for the layer. 802 803 Returns: 804 Single TensorSpec or nested structure of TensorSpec objects, describing 805 how the layer would transform the provided input. 806 807 Raises: 808 TypeError: If input_signature contains a non-TensorSpec object. 809 """ 810 def check_type_return_shape(s): 811 if not isinstance(s, tensor_spec.TensorSpec): 812 raise TypeError('Only TensorSpec signature types are supported, ' 813 'but saw signature entry: {}.'.format(s)) 814 return s.shape 815 input_shape = nest.map_structure(check_type_return_shape, input_signature) 816 output_shape = self.compute_output_shape(input_shape) 817 dtype = self._compute_dtype 818 if dtype is None: 819 input_dtypes = [s.dtype for s in nest.flatten(input_signature)] 820 # Default behavior when self.dtype is None, is to use the first input's 821 # dtype. 822 dtype = input_dtypes[0] 823 return nest.map_structure( 824 lambda s: tensor_spec.TensorSpec(dtype=dtype, shape=s), 825 output_shape) 826 827 def _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs): 828 if self.dynamic: 829 # We will use static shape inference to return symbolic tensors 830 # matching the specifications of the layer outputs. 831 # Since `self.dynamic` is True, we will never attempt to 832 # run the underlying TF graph (which is disconnected). 833 # TODO(fchollet): consider py_func as an alternative, which 834 # would enable us to run the underlying graph if needed. 835 input_signature = nest.map_structure( 836 lambda x: tensor_spec.TensorSpec(shape=x.shape, dtype=x.dtype), 837 inputs) 838 output_signature = self.compute_output_signature(input_signature) 839 return nest.map_structure(keras_tensor.KerasTensor, output_signature) 840 else: 841 return self._infer_output_signature(inputs, args, kwargs, input_masks) 842 843 def _infer_output_signature(self, inputs, args, kwargs, input_masks): 844 """TODO(kaftan): Docstring.""" 845 846 call_fn = self.call 847 # Wrapping `call` function in autograph to allow for dynamic control 848 # flow and control dependencies in call. We are limiting this to 849 # subclassed layers as autograph is strictly needed only for 850 # subclassed layers and models. 851 # tf_convert will respect the value of autograph setting in the 852 # enclosing tf.function, if any. 853 if (base_layer_utils.is_subclassed(self) and 854 not base_layer_utils.from_saved_model(self)): 855 call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx()) 856 857 # We enter a scratch graph and build placeholder inputs inside of it that 858 # match the input args. 859 # We then call the layer inside of the scratch graph to identify the 860 # output signatures, then we build KerasTensors corresponding to those 861 # outputs. 862 scratch_graph = func_graph.FuncGraph(str(self.name) + '_scratch_graph') 863 with scratch_graph.as_default(): 864 inputs = nest.map_structure( 865 keras_tensor.keras_tensor_to_placeholder, inputs) 866 args = nest.map_structure( 867 keras_tensor.keras_tensor_to_placeholder, args) 868 kwargs = nest.map_structure( 869 keras_tensor.keras_tensor_to_placeholder, kwargs) 870 input_masks = nest.map_structure( 871 keras_tensor.keras_tensor_to_placeholder, input_masks) 872 873 inputs = self._maybe_cast_inputs(inputs) 874 875 with backend.name_scope(self._name_scope()): 876 with autocast_variable.enable_auto_cast_variables( 877 self._compute_dtype_object): 878 # Build layer if applicable (if the `build` method has been 879 # overridden). 880 # TODO(kaftan): do we maybe_build here, or have we already done it? 881 self._maybe_build(inputs) 882 outputs = call_fn(inputs, *args, **kwargs) 883 884 self._handle_activity_regularization(inputs, outputs) 885 self._set_mask_metadata(inputs, outputs, input_masks, 886 build_graph=False) 887 outputs = nest.map_structure( 888 keras_tensor.keras_tensor_from_tensor, outputs) 889 890 if hasattr(self, '_set_inputs') and not self.inputs: 891 # TODO(kaftan): figure out if we need to do this at all 892 # Subclassed network: explicitly set metadata normally set by 893 # a call to self._set_inputs(). 894 self._set_inputs(inputs, outputs) 895 del scratch_graph 896 return outputs 897 898 @generic_utils.default 899 def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument 900 """Computes an output mask tensor. 901 902 Args: 903 inputs: Tensor or list of tensors. 904 mask: Tensor or list of tensors. 905 906 Returns: 907 None or a tensor (or list of tensors, 908 one per output tensor of the layer). 909 """ 910 if not self._supports_masking: 911 if any(m is not None for m in nest.flatten(mask)): 912 raise TypeError('Layer ' + self.name + ' does not support masking, ' 913 'but was passed an input_mask: ' + str(mask)) 914 # masking not explicitly supported: return None as mask. 915 return None 916 # if masking is explicitly supported, by default 917 # carry over the input mask 918 return mask 919 920 def __call__(self, *args, **kwargs): 921 """Wraps `call`, applying pre- and post-processing steps. 922 923 Args: 924 *args: Positional arguments to be passed to `self.call`. 925 **kwargs: Keyword arguments to be passed to `self.call`. 926 927 Returns: 928 Output tensor(s). 929 930 Note: 931 - The following optional keyword arguments are reserved for specific uses: 932 * `training`: Boolean scalar tensor of Python boolean indicating 933 whether the `call` is meant for training or inference. 934 * `mask`: Boolean input mask. 935 - If the layer's `call` method takes a `mask` argument (as some Keras 936 layers do), its default value will be set to the mask generated 937 for `inputs` by the previous layer (if `input` did come from 938 a layer that generated a corresponding mask, i.e. if it came from 939 a Keras layer with masking support. 940 - If the layer is not built, the method will call `build`. 941 942 Raises: 943 ValueError: if the layer's `call` method returns None (an invalid value). 944 RuntimeError: if `super().__init__()` was not called in the constructor. 945 """ 946 if not hasattr(self, '_thread_local'): 947 raise RuntimeError( 948 'You must call `super().__init__()` in the layer constructor.') 949 950 # `inputs` (the first arg in the method spec) is special cased in 951 # layer call due to historical reasons. 952 # This special casing currently takes the form of: 953 # - 'inputs' must be explicitly passed. A layer cannot have zero arguments, 954 # and inputs cannot have been provided via the default value of a kwarg. 955 # - numpy/scalar values in `inputs` get converted to tensors 956 # - implicit masks / mask metadata are only collected from 'inputs` 957 # - Layers are built using shape info from 'inputs' only 958 # - input_spec compatibility is only checked against `inputs` 959 # - mixed precision casting (autocast) is only applied to `inputs`, 960 # not to any other argument. 961 # - setting the SavedModel saving spec. 962 inputs, args, kwargs = self._split_out_first_arg(args, kwargs) 963 input_list = nest.flatten(inputs) 964 965 # Functional Model construction mode is invoked when `Layer`s are called on 966 # symbolic `KerasTensor`s, i.e.: 967 # >> inputs = tf.keras.Input(10) 968 # >> outputs = MyLayer()(inputs) # Functional construction mode. 969 # >> model = tf.keras.Model(inputs, outputs) 970 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list): 971 return self._functional_construction_call(inputs, args, kwargs, 972 input_list) 973 974 # Maintains info about the `Layer.call` stack. 975 call_context = base_layer_utils.call_context() 976 977 # Accept NumPy and scalar inputs by converting to Tensors. 978 if any(isinstance(x, ( 979 np_arrays.ndarray, np.ndarray, float, int)) for x in input_list): 980 inputs = nest.map_structure(_convert_numpy_or_python_types, inputs) 981 input_list = nest.flatten(inputs) 982 983 # Handle `mask` propagation from previous layer to current layer. Masks can 984 # be propagated explicitly via the `mask` argument, or implicitly via 985 # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed 986 # explicitly take priority. 987 input_masks, mask_is_implicit = self._get_input_masks( 988 inputs, input_list, args, kwargs) 989 if self._expects_mask_arg and mask_is_implicit: 990 kwargs['mask'] = input_masks 991 992 # Training mode for `Layer.call` is set via (in order of priority): 993 # (1) The `training` argument passed to this `Layer.call`, if it is not None 994 # (2) The training mode of an outer `Layer.call`. 995 # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set) 996 # (4) Any non-None default value for `training` specified in the call 997 # signature 998 # (5) False (treating the layer as if it's in inference) 999 args, kwargs, training_mode = self._set_training_mode( 1000 args, kwargs, call_context) 1001 1002 # Losses are cleared for all sublayers on the outermost `Layer.call`. 1003 # Losses are not cleared on inner `Layer.call`s, because sublayers can be 1004 # called multiple times. 1005 if not call_context.in_call: 1006 self._clear_losses() 1007 1008 eager = context.executing_eagerly() 1009 with call_context.enter( 1010 layer=self, 1011 inputs=inputs, 1012 build_graph=not eager, 1013 training=training_mode): 1014 1015 if self._autocast: 1016 inputs = self._maybe_cast_inputs(inputs, input_list) 1017 1018 input_spec.assert_input_compatibility(self.input_spec, inputs, self.name) 1019 if eager: 1020 call_fn = self.call 1021 name_scope = self._name 1022 else: 1023 name_scope = self._name_scope() # Avoid autoincrementing. 1024 call_fn = self._autographed_call() 1025 1026 with ops.name_scope_v2(name_scope): 1027 if not self.built: 1028 self._maybe_build(inputs) 1029 1030 with autocast_variable.enable_auto_cast_variables( 1031 self._compute_dtype_object): 1032 outputs = call_fn(inputs, *args, **kwargs) 1033 1034 if self._activity_regularizer: 1035 self._handle_activity_regularization(inputs, outputs) 1036 if self._supports_masking: 1037 self._set_mask_metadata(inputs, outputs, input_masks, not eager) 1038 if self._saved_model_inputs_spec is None: 1039 self._set_save_spec(inputs) 1040 1041 return outputs 1042 1043 def _functional_construction_call(self, inputs, args, kwargs, input_list): 1044 call_context = base_layer_utils.call_context() 1045 1046 # Accept NumPy and scalar inputs by converting to Tensors. 1047 if any(isinstance(x, ( 1048 np_arrays.ndarray, np.ndarray, float, int)) for x in input_list): 1049 1050 def _convert_non_tensor(x): 1051 # Don't call `ops.convert_to_tensor` on all `inputs` because 1052 # `SparseTensors` can't be converted to `Tensor`. 1053 if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)): 1054 return ops.convert_to_tensor_v2_with_dispatch(x) 1055 return x 1056 1057 inputs = nest.map_structure(_convert_non_tensor, inputs) 1058 input_list = nest.flatten(inputs) 1059 1060 # Handle `mask` propagation from previous layer to current layer. Masks can 1061 # be propagated explicitly via the `mask` argument, or implicitly via 1062 # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed 1063 # explicitly take priority. 1064 mask_arg_passed_by_framework = False 1065 input_masks, mask_is_implicit = self._get_input_masks( 1066 inputs, input_list, args, kwargs) 1067 if self._expects_mask_arg and mask_is_implicit: 1068 kwargs['mask'] = input_masks 1069 mask_arg_passed_by_framework = True 1070 1071 # If `training` argument is None or not explicitly passed, 1072 # propagate `training` value from this layer's calling layer. 1073 training_value = None 1074 training_arg_passed_by_framework = False 1075 # Priority 1: `training` was explicitly passed a non-None value. 1076 if self._call_arg_was_passed('training', args, kwargs): 1077 training_value = self._get_call_arg_value('training', args, kwargs) 1078 if not self._expects_training_arg: 1079 kwargs.pop('training') 1080 1081 if training_value is None: 1082 # Priority 2: `training` was passed to a parent layer. 1083 if call_context.training is not None: 1084 training_value = call_context.training 1085 # Priority 3: `learning_phase()` has been set. 1086 elif backend.global_learning_phase_is_set(): 1087 training_value = backend.learning_phase() 1088 # Force the training_value to be bool type which matches to the contract 1089 # for layer/model call args. 1090 if tensor_util.is_tf_type(training_value): 1091 training_value = math_ops.cast(training_value, dtypes.bool) 1092 else: 1093 training_value = bool(training_value) 1094 # Priority 4: trace layer with the default training argument specified 1095 # in the `call` signature (or in inference mode if the `call` signature 1096 # specifies no non-None default). 1097 else: 1098 training_value = self._default_training_arg 1099 # In cases (2), (3), (4) the training argument is passed automatically 1100 # by the framework, and will not be hard-coded into the model. 1101 if self._expects_training_arg: 1102 args, kwargs = self._set_call_arg_value('training', training_value, 1103 args, kwargs) 1104 training_arg_passed_by_framework = True 1105 1106 if keras_tensor.keras_tensors_enabled(): 1107 with call_context.enter( 1108 layer=self, inputs=inputs, build_graph=True, training=training_value): 1109 # Check input assumptions set after layer building, e.g. input shape. 1110 outputs = self._keras_tensor_symbolic_call( 1111 inputs, input_masks, args, kwargs) 1112 1113 if outputs is None: 1114 raise ValueError('A layer\'s `call` method should return a ' 1115 'Tensor or a list of Tensors, not None ' 1116 '(layer: ' + self.name + ').') 1117 if training_arg_passed_by_framework: 1118 args, kwargs = self._set_call_arg_value( 1119 'training', None, args, kwargs, pop_kwarg_if_none=True) 1120 if mask_arg_passed_by_framework: 1121 kwargs.pop('mask') 1122 # Node connectivity does not special-case the first argument. 1123 outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, 1124 outputs) 1125 return outputs 1126 1127 # Only create Keras history if at least one tensor originates from a 1128 # `keras.Input`. Otherwise this Layer may be being used outside the Keras 1129 # framework. 1130 # TODO(kaftan): make this not special case inputs 1131 if base_layer_utils.needs_keras_history(inputs): 1132 base_layer_utils.create_keras_history(inputs) 1133 1134 with call_context.enter( 1135 layer=self, inputs=inputs, build_graph=True, training=training_value): 1136 # Symbolic execution on symbolic tensors. We will attempt to build 1137 # the corresponding TF subgraph inside `backend.get_graph()` 1138 # TODO(reedwm): We should assert input compatibility after the inputs 1139 # are casted, not before. 1140 input_spec.assert_input_compatibility(self.input_spec, inputs, self.name) 1141 graph = backend.get_graph() 1142 # Use `self._name_scope()` to avoid auto-incrementing the name. 1143 with graph.as_default(), backend.name_scope(self._name_scope()): 1144 # Build layer if applicable (if the `build` method has been 1145 # overridden). 1146 self._maybe_build(inputs) 1147 cast_inputs = self._maybe_cast_inputs(inputs, input_list) 1148 1149 if not self.dynamic: 1150 # Wrapping `call` function in autograph to allow for dynamic control 1151 # flow and control dependencies in call. We are limiting this to 1152 # subclassed layers as autograph is strictly needed only for 1153 # subclassed layers and models. 1154 # tf_convert will respect the value of autograph setting in the 1155 # enclosing tf.function, if any. 1156 if (base_layer_utils.is_subclassed(self) and 1157 not base_layer_utils.from_saved_model(self)): 1158 call_fn = autograph.tf_convert(self.call, 1159 ag_ctx.control_status_ctx()) 1160 else: 1161 call_fn = self.call 1162 1163 try: 1164 with autocast_variable.enable_auto_cast_variables( 1165 self._compute_dtype_object): 1166 outputs = call_fn(cast_inputs, *args, **kwargs) 1167 1168 except errors.OperatorNotAllowedInGraphError as e: 1169 raise TypeError('You are attempting to use Python control ' 1170 'flow in a layer that was not declared to be ' 1171 'dynamic. Pass `dynamic=True` to the class ' 1172 'constructor.\nEncountered error:\n"""\n' + str(e) + 1173 '\n"""') 1174 else: 1175 # We will use static shape inference to return symbolic tensors 1176 # matching the specifications of the layer outputs. 1177 # Since `self.dynamic` is True, we will never attempt to 1178 # run the underlying TF graph (which is disconnected). 1179 # TODO(fchollet): consider py_func as an alternative, which 1180 # would enable us to run the underlying graph if needed. 1181 outputs = self._symbolic_call(inputs) 1182 1183 if outputs is None: 1184 raise ValueError('A layer\'s `call` method should return a ' 1185 'Tensor or a list of Tensors, not None ' 1186 '(layer: ' + self.name + ').') 1187 # TODO(kaftan): This should be 'any' and check all args 1188 if base_layer_utils.have_all_keras_metadata(inputs): 1189 if training_arg_passed_by_framework: 1190 args, kwargs = self._set_call_arg_value( 1191 'training', None, args, kwargs, pop_kwarg_if_none=True) 1192 if mask_arg_passed_by_framework: 1193 kwargs.pop('mask') 1194 # Node connectivity does not special-case the first argument. 1195 outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, 1196 outputs) 1197 self._handle_activity_regularization(inputs, outputs) 1198 self._set_mask_metadata(inputs, outputs, input_masks, True) 1199 if hasattr(self, '_set_inputs') and not self.inputs: 1200 # Subclassed network: explicitly set metadata normally set by 1201 # a call to self._set_inputs(). 1202 self._set_inputs(cast_inputs, outputs) 1203 1204 return outputs 1205 1206 def _set_training_mode(self, args, kwargs, call_context): 1207 training_mode = None 1208 if self._expects_training_arg: 1209 # (1) `training` was passed to this `Layer.call`. 1210 if self._call_arg_was_passed('training', args, kwargs): 1211 training_mode = self._get_call_arg_value('training', args, kwargs) 1212 # If no `training` arg was passed, or `None` was explicitly passed, 1213 # the framework will make a decision about the training mode is. 1214 if training_mode is None: 1215 call_ctx_training = call_context.training 1216 # (2) `training` mode is inferred from an outer `Layer.call`. 1217 if call_ctx_training is not None: 1218 training_mode = call_ctx_training 1219 # (3) User set `tf.keras.backend.set_learning_phase`. 1220 elif backend.global_learning_phase_is_set(): 1221 training_mode = backend.learning_phase() 1222 # Ensure value is a `bool` or `tf.bool`. 1223 if isinstance(training_mode, bool): 1224 pass 1225 elif tensor_util.is_tf_type(training_mode): 1226 training_mode = math_ops.cast(training_mode, dtypes.bool) 1227 else: 1228 training_mode = bool(training_mode) 1229 # (4) We default to using `call`'s default value for `training`, 1230 # or treating the layer as if it is in inference if no non-None default 1231 # is specified in the `call` signature. 1232 else: 1233 training_mode = self._default_training_arg 1234 1235 # For case (2), (3), (4) `training` arg is passed by framework. 1236 args, kwargs = self._set_call_arg_value('training', training_mode, args, 1237 kwargs) 1238 else: 1239 if 'training' in kwargs: 1240 # `training` was passed to this `Layer` but is not needed for 1241 # `Layer.call`. It will set the default mode for inner `Layer.call`s. 1242 training_mode = kwargs.pop('training') 1243 else: 1244 # Grab the current `training` mode from any outer `Layer.call`. 1245 training_mode = call_context.training 1246 1247 return args, kwargs, training_mode 1248 1249 def _autographed_call(self): 1250 # Wrapping `call` function in autograph to allow for dynamic control 1251 # flow and control dependencies in call. We are limiting this to 1252 # subclassed layers as autograph is strictly needed only for 1253 # subclassed layers and models. 1254 # tf_convert will respect the value of autograph setting in the 1255 # enclosing tf.function, if any. 1256 if (base_layer_utils.is_subclassed(self) and 1257 not base_layer_utils.from_saved_model(self)): 1258 return autograph.tf_convert(self.call, ag_ctx.control_status_ctx()) 1259 else: 1260 return self.call 1261 1262 @property 1263 def dtype(self): 1264 """The dtype of the layer weights. 1265 1266 This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless 1267 mixed precision is used, this is the same as `Layer.compute_dtype`, the 1268 dtype of the layer's computations. 1269 """ 1270 return self._dtype_policy.variable_dtype 1271 1272 @property 1273 def name(self): 1274 """Name of the layer (string), set in the constructor.""" 1275 return self._name 1276 1277 @property 1278 def supports_masking(self): 1279 """Whether this layer supports computing a mask using `compute_mask`.""" 1280 return self._supports_masking 1281 1282 @supports_masking.setter 1283 def supports_masking(self, value): 1284 self._supports_masking = value 1285 1286 @property 1287 def dynamic(self): 1288 """Whether the layer is dynamic (eager-only); set in the constructor.""" 1289 return any(layer._dynamic for layer in self._flatten_layers()) 1290 1291 @property 1292 @doc_controls.do_not_doc_inheritable 1293 def stateful(self): 1294 return any(layer._stateful for layer in self._flatten_layers()) 1295 1296 @stateful.setter 1297 def stateful(self, value): 1298 self._stateful = value 1299 1300 @property 1301 def trainable(self): 1302 return self._trainable 1303 1304 @trainable.setter 1305 def trainable(self, value): 1306 for layer in self._flatten_layers(): 1307 layer._trainable = value 1308 1309 @property 1310 def activity_regularizer(self): 1311 """Optional regularizer function for the output of this layer.""" 1312 return self._activity_regularizer 1313 1314 @activity_regularizer.setter 1315 def activity_regularizer(self, regularizer): 1316 """Optional regularizer function for the output of this layer.""" 1317 self._activity_regularizer = regularizer 1318 1319 @property 1320 def input_spec(self): 1321 """`InputSpec` instance(s) describing the input format for this layer. 1322 1323 When you create a layer subclass, you can set `self.input_spec` to enable 1324 the layer to run input compatibility checks when it is called. 1325 Consider a `Conv2D` layer: it can only be called on a single input tensor 1326 of rank 4. As such, you can set, in `__init__()`: 1327 1328 ```python 1329 self.input_spec = tf.keras.layers.InputSpec(ndim=4) 1330 ``` 1331 1332 Now, if you try to call the layer on an input that isn't rank 4 1333 (for instance, an input of shape `(2,)`, it will raise a nicely-formatted 1334 error: 1335 1336 ``` 1337 ValueError: Input 0 of layer conv2d is incompatible with the layer: 1338 expected ndim=4, found ndim=1. Full shape received: [2] 1339 ``` 1340 1341 Input checks that can be specified via `input_spec` include: 1342 - Structure (e.g. a single input, a list of 2 inputs, etc) 1343 - Shape 1344 - Rank (ndim) 1345 - Dtype 1346 1347 For more information, see `tf.keras.layers.InputSpec`. 1348 1349 Returns: 1350 A `tf.keras.layers.InputSpec` instance, or nested structure thereof. 1351 """ 1352 return self._input_spec 1353 1354 @input_spec.setter 1355 # Must be decorated to prevent tracking, since the input_spec can be nested 1356 # InputSpec objects. 1357 @trackable.no_automatic_dependency_tracking 1358 def input_spec(self, value): 1359 for v in nest.flatten(value): 1360 if v is not None and not isinstance(v, InputSpec): 1361 raise TypeError('Layer input_spec must be an instance of InputSpec. ' 1362 'Got: {}'.format(v)) 1363 self._input_spec = value 1364 1365 @property 1366 def trainable_weights(self): 1367 """List of all trainable weights tracked by this layer. 1368 1369 Trainable weights are updated via gradient descent during training. 1370 1371 Returns: 1372 A list of trainable variables. 1373 """ 1374 if self.trainable: 1375 children_weights = self._gather_children_attribute('trainable_variables') 1376 return self._dedup_weights(self._trainable_weights + children_weights) 1377 else: 1378 return [] 1379 1380 @property 1381 def non_trainable_weights(self): 1382 """List of all non-trainable weights tracked by this layer. 1383 1384 Non-trainable weights are *not* updated during training. They are expected 1385 to be updated manually in `call()`. 1386 1387 Returns: 1388 A list of non-trainable variables. 1389 """ 1390 if self.trainable: 1391 children_weights = self._gather_children_attribute( 1392 'non_trainable_variables') 1393 non_trainable_weights = self._non_trainable_weights + children_weights 1394 else: 1395 children_weights = self._gather_children_attribute('variables') 1396 non_trainable_weights = ( 1397 self._trainable_weights + self._non_trainable_weights + 1398 children_weights) 1399 return self._dedup_weights(non_trainable_weights) 1400 1401 @property 1402 def weights(self): 1403 """Returns the list of all layer variables/weights. 1404 1405 Returns: 1406 A list of variables. 1407 """ 1408 return self.trainable_weights + self.non_trainable_weights 1409 1410 @property 1411 @doc_controls.do_not_generate_docs 1412 def updates(self): 1413 warnings.warn('`layer.updates` will be removed in a future version. ' 1414 'This property should not be used in TensorFlow 2.0, ' 1415 'as `updates` are applied automatically.') 1416 if keras_tensor.keras_tensors_enabled(): 1417 return [] 1418 1419 collected_updates = [] 1420 all_layers = self._flatten_layers() 1421 with backend.get_graph().as_default(): 1422 for layer in all_layers: 1423 if not layer.trainable and not layer.stateful: 1424 continue 1425 for u in layer._updates: 1426 if callable(u): 1427 u = u() 1428 collected_updates.append(u) 1429 return collected_updates 1430 1431 @property 1432 def losses(self): 1433 """List of losses added using the `add_loss()` API. 1434 1435 Variable regularization tensors are created when this property is accessed, 1436 so it is eager safe: accessing `losses` under a `tf.GradientTape` will 1437 propagate gradients back to the corresponding variables. 1438 1439 Examples: 1440 1441 >>> class MyLayer(tf.keras.layers.Layer): 1442 ... def call(self, inputs): 1443 ... self.add_loss(tf.abs(tf.reduce_mean(inputs))) 1444 ... return inputs 1445 >>> l = MyLayer() 1446 >>> l(np.ones((10, 1))) 1447 >>> l.losses 1448 [1.0] 1449 1450 >>> inputs = tf.keras.Input(shape=(10,)) 1451 >>> x = tf.keras.layers.Dense(10)(inputs) 1452 >>> outputs = tf.keras.layers.Dense(1)(x) 1453 >>> model = tf.keras.Model(inputs, outputs) 1454 >>> # Activity regularization. 1455 >>> len(model.losses) 1456 0 1457 >>> model.add_loss(tf.abs(tf.reduce_mean(x))) 1458 >>> len(model.losses) 1459 1 1460 1461 >>> inputs = tf.keras.Input(shape=(10,)) 1462 >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones') 1463 >>> x = d(inputs) 1464 >>> outputs = tf.keras.layers.Dense(1)(x) 1465 >>> model = tf.keras.Model(inputs, outputs) 1466 >>> # Weight regularization. 1467 >>> model.add_loss(lambda: tf.reduce_mean(d.kernel)) 1468 >>> model.losses 1469 [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>] 1470 1471 Returns: 1472 A list of tensors. 1473 """ 1474 collected_losses = [] 1475 for layer in self._flatten_layers(): 1476 # If any eager losses are present, we assume the model to be part of an 1477 # eager training loop (either a custom one or the one used when 1478 # `run_eagerly=True`) and so we always return just the eager losses. 1479 if layer._eager_losses: 1480 # Filter placeholder losses that may have been added by revived layers. 1481 # (see base_layer_utils for details). 1482 if (layer._eager_losses[0] is 1483 not base_layer_utils.REVIVED_LOSS_PLACEHOLDER): 1484 collected_losses.extend(layer._eager_losses) 1485 else: 1486 collected_losses.extend(layer._losses) 1487 for regularizer in layer._callable_losses: 1488 loss_tensor = regularizer() 1489 if loss_tensor is not None: 1490 collected_losses.append(loss_tensor) 1491 return collected_losses 1492 1493 def add_loss(self, losses, **kwargs): 1494 """Add loss tensor(s), potentially dependent on layer inputs. 1495 1496 Some losses (for instance, activity regularization losses) may be dependent 1497 on the inputs passed when calling a layer. Hence, when reusing the same 1498 layer on different inputs `a` and `b`, some entries in `layer.losses` may 1499 be dependent on `a` and some on `b`. This method automatically keeps track 1500 of dependencies. 1501 1502 This method can be used inside a subclassed layer or model's `call` 1503 function, in which case `losses` should be a Tensor or list of Tensors. 1504 1505 Example: 1506 1507 ```python 1508 class MyLayer(tf.keras.layers.Layer): 1509 def call(self, inputs): 1510 self.add_loss(tf.abs(tf.reduce_mean(inputs))) 1511 return inputs 1512 ``` 1513 1514 This method can also be called directly on a Functional Model during 1515 construction. In this case, any loss Tensors passed to this Model must 1516 be symbolic and be able to be traced back to the model's `Input`s. These 1517 losses become part of the model's topology and are tracked in `get_config`. 1518 1519 Example: 1520 1521 ```python 1522 inputs = tf.keras.Input(shape=(10,)) 1523 x = tf.keras.layers.Dense(10)(inputs) 1524 outputs = tf.keras.layers.Dense(1)(x) 1525 model = tf.keras.Model(inputs, outputs) 1526 # Activity regularization. 1527 model.add_loss(tf.abs(tf.reduce_mean(x))) 1528 ``` 1529 1530 If this is not the case for your loss (if, for example, your loss references 1531 a `Variable` of one of the model's layers), you can wrap your loss in a 1532 zero-argument lambda. These losses are not tracked as part of the model's 1533 topology since they can't be serialized. 1534 1535 Example: 1536 1537 ```python 1538 inputs = tf.keras.Input(shape=(10,)) 1539 d = tf.keras.layers.Dense(10) 1540 x = d(inputs) 1541 outputs = tf.keras.layers.Dense(1)(x) 1542 model = tf.keras.Model(inputs, outputs) 1543 # Weight regularization. 1544 model.add_loss(lambda: tf.reduce_mean(d.kernel)) 1545 ``` 1546 1547 Args: 1548 losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses 1549 may also be zero-argument callables which create a loss tensor. 1550 **kwargs: Additional keyword arguments for backward compatibility. 1551 Accepted values: 1552 inputs - Deprecated, will be automatically inferred. 1553 """ 1554 kwargs.pop('inputs', None) 1555 if kwargs: 1556 raise TypeError('Unknown keyword arguments: %s' % (kwargs.keys(),)) 1557 1558 def _tag_callable(loss): 1559 """Tags callable loss tensor as `_unconditional_loss`.""" 1560 if callable(loss): 1561 # We run the loss without autocasting, as regularizers are often 1562 # numerically unstable in float16. 1563 with autocast_variable.enable_auto_cast_variables(None): 1564 loss = loss() 1565 if loss is None: 1566 return None # Will be filtered out when computing the .losses property 1567 if not tensor_util.is_tf_type(loss): 1568 loss = ops.convert_to_tensor_v2_with_dispatch( 1569 loss, dtype=backend.floatx()) 1570 loss._unconditional_loss = True # pylint: disable=protected-access 1571 return loss 1572 1573 losses = nest.flatten(losses) 1574 1575 callable_losses = [] 1576 eager_losses = [] 1577 symbolic_losses = [] 1578 for loss in losses: 1579 if callable(loss): 1580 callable_losses.append(functools.partial(_tag_callable, loss)) 1581 continue 1582 if loss is None: 1583 continue 1584 if not tensor_util.is_tf_type(loss) and not isinstance( 1585 loss, keras_tensor.KerasTensor): 1586 loss = ops.convert_to_tensor_v2_with_dispatch( 1587 loss, dtype=backend.floatx()) 1588 # TF Functions should take the eager path. 1589 if ((tf_utils.is_symbolic_tensor(loss) or 1590 isinstance(loss, keras_tensor.KerasTensor)) and 1591 not base_layer_utils.is_in_tf_function()): 1592 symbolic_losses.append(loss) 1593 elif tensor_util.is_tf_type(loss): 1594 eager_losses.append(loss) 1595 1596 self._callable_losses.extend(callable_losses) 1597 1598 in_call_context = base_layer_utils.call_context().in_call 1599 if eager_losses and not in_call_context: 1600 raise ValueError( 1601 'Expected a symbolic Tensors or a callable for the loss value. ' 1602 'Please wrap your loss computation in a zero argument `lambda`.') 1603 1604 self._eager_losses.extend(eager_losses) 1605 1606 if in_call_context and not keras_tensor.keras_tensors_enabled(): 1607 for symbolic_loss in symbolic_losses: 1608 self._losses.append(symbolic_loss) 1609 else: 1610 for symbolic_loss in symbolic_losses: 1611 if getattr(self, '_is_graph_network', False): 1612 self._graph_network_add_loss(symbolic_loss) 1613 else: 1614 # Possible a loss was added in a Layer's `build`. 1615 self._losses.append(symbolic_loss) 1616 1617 def _clear_losses(self): 1618 """Used every step in eager to reset losses.""" 1619 # Set to thread local directly to avoid Layer.__setattr__ overhead. 1620 if not getattr(self, '_self_tracked_trackables', 1621 None): # Fast path for single Layer. 1622 self._thread_local._eager_losses = [] 1623 else: 1624 for layer in self._flatten_layers(): 1625 layer._thread_local._eager_losses = [] 1626 1627 @property 1628 def metrics(self): 1629 """List of metrics added using the `add_metric()` API. 1630 1631 Example: 1632 1633 >>> input = tf.keras.layers.Input(shape=(3,)) 1634 >>> d = tf.keras.layers.Dense(2) 1635 >>> output = d(input) 1636 >>> d.add_metric(tf.reduce_max(output), name='max') 1637 >>> d.add_metric(tf.reduce_min(output), name='min') 1638 >>> [m.name for m in d.metrics] 1639 ['max', 'min'] 1640 1641 Returns: 1642 A list of `Metric` objects. 1643 """ 1644 collected_metrics = [] 1645 for layer in self._flatten_layers(): 1646 with layer._metrics_lock: 1647 collected_metrics.extend(layer._metrics) 1648 return collected_metrics 1649 1650 def add_metric(self, value, name=None, **kwargs): 1651 """Adds metric tensor to the layer. 1652 1653 This method can be used inside the `call()` method of a subclassed layer 1654 or model. 1655 1656 ```python 1657 class MyMetricLayer(tf.keras.layers.Layer): 1658 def __init__(self): 1659 super(MyMetricLayer, self).__init__(name='my_metric_layer') 1660 self.mean = tf.keras.metrics.Mean(name='metric_1') 1661 1662 def call(self, inputs): 1663 self.add_metric(self.mean(x)) 1664 self.add_metric(tf.reduce_sum(x), name='metric_2') 1665 return inputs 1666 ``` 1667 1668 This method can also be called directly on a Functional Model during 1669 construction. In this case, any tensor passed to this Model must 1670 be symbolic and be able to be traced back to the model's `Input`s. These 1671 metrics become part of the model's topology and are tracked when you 1672 save the model via `save()`. 1673 1674 ```python 1675 inputs = tf.keras.Input(shape=(10,)) 1676 x = tf.keras.layers.Dense(10)(inputs) 1677 outputs = tf.keras.layers.Dense(1)(x) 1678 model = tf.keras.Model(inputs, outputs) 1679 model.add_metric(math_ops.reduce_sum(x), name='metric_1') 1680 ``` 1681 1682 Note: Calling `add_metric()` with the result of a metric object on a 1683 Functional Model, as shown in the example below, is not supported. This is 1684 because we cannot trace the metric result tensor back to the model's inputs. 1685 1686 ```python 1687 inputs = tf.keras.Input(shape=(10,)) 1688 x = tf.keras.layers.Dense(10)(inputs) 1689 outputs = tf.keras.layers.Dense(1)(x) 1690 model = tf.keras.Model(inputs, outputs) 1691 model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1') 1692 ``` 1693 1694 Args: 1695 value: Metric tensor. 1696 name: String metric name. 1697 **kwargs: Additional keyword arguments for backward compatibility. 1698 Accepted values: 1699 `aggregation` - When the `value` tensor provided is not the result of 1700 calling a `keras.Metric` instance, it will be aggregated by default 1701 using a `keras.Metric.Mean`. 1702 """ 1703 kwargs_keys = list(kwargs.keys()) 1704 if (len(kwargs_keys) > 1 or 1705 (len(kwargs_keys) == 1 and kwargs_keys[0] != 'aggregation')): 1706 raise TypeError('Unknown keyword arguments: ', str(kwargs.keys())) 1707 1708 from_metric_obj = hasattr(value, '_metric_obj') 1709 if keras_tensor.keras_tensors_enabled(): 1710 is_symbolic = isinstance(value, keras_tensor.KerasTensor) 1711 else: 1712 is_symbolic = tf_utils.is_symbolic_tensor(value) 1713 in_call_context = base_layer_utils.call_context().in_call 1714 1715 if name is None and not from_metric_obj: 1716 # Eg. `self.add_metric(math_ops.reduce_sum(x))` 1717 # In eager mode, we use metric name to lookup a metric. Without a name, 1718 # a new Mean metric wrapper will be created on every model/layer call. 1719 # So, we raise an error when no name is provided. 1720 # We will do the same for symbolic mode for consistency although a name 1721 # will be generated if no name is provided. 1722 1723 # We will not raise this error in the foll use case for the sake of 1724 # consistency as name in provided in the metric constructor. 1725 # mean = metrics.Mean(name='my_metric') 1726 # model.add_metric(mean(outputs)) 1727 raise ValueError('Please provide a name for your metric like ' 1728 '`self.add_metric(tf.reduce_sum(inputs), ' 1729 'name=\'mean_activation\')`') 1730 elif from_metric_obj: 1731 name = value._metric_obj.name 1732 1733 if not in_call_context and not is_symbolic: 1734 raise ValueError('Expected a symbolic Tensor for the metric value, ' 1735 'received: ' + str(value)) 1736 1737 # If a metric was added in a Layer's `call` or `build`. 1738 if in_call_context or not getattr(self, '_is_graph_network', False): 1739 # TF Function path should take the eager path. 1740 1741 # If the given metric is available in `metrics` list we just update state 1742 # on it, otherwise we create a new metric instance and 1743 # add it to the `metrics` list. 1744 metric_obj = getattr(value, '_metric_obj', None) 1745 # Tensors that come from a Metric object already updated the Metric state. 1746 should_update_state = not metric_obj 1747 name = metric_obj.name if metric_obj else name 1748 1749 with self._metrics_lock: 1750 match = self._get_existing_metric(name) 1751 if match: 1752 metric_obj = match 1753 elif metric_obj: 1754 self._metrics.append(metric_obj) 1755 else: 1756 # Build the metric object with the value's dtype if it defines one 1757 metric_obj = metrics_mod.Mean( 1758 name=name, dtype=getattr(value, 'dtype', None)) 1759 self._metrics.append(metric_obj) 1760 1761 if should_update_state: 1762 metric_obj(value) 1763 else: 1764 if from_metric_obj: 1765 raise ValueError('Using the result of calling a `Metric` object ' 1766 'when calling `add_metric` on a Functional ' 1767 'Model is not supported. Please pass the ' 1768 'Tensor to monitor directly.') 1769 1770 # Insert layers into the Keras Graph Network. 1771 aggregation = None if from_metric_obj else 'mean' 1772 self._graph_network_add_metric(value, aggregation, name) 1773 1774 @doc_controls.do_not_doc_inheritable 1775 def add_update(self, updates, inputs=None): 1776 """Add update op(s), potentially dependent on layer inputs. 1777 1778 Weight updates (for instance, the updates of the moving mean and variance 1779 in a BatchNormalization layer) may be dependent on the inputs passed 1780 when calling a layer. Hence, when reusing the same layer on 1781 different inputs `a` and `b`, some entries in `layer.updates` may be 1782 dependent on `a` and some on `b`. This method automatically keeps track 1783 of dependencies. 1784 1785 This call is ignored when eager execution is enabled (in that case, variable 1786 updates are run on the fly and thus do not need to be tracked for later 1787 execution). 1788 1789 Args: 1790 updates: Update op, or list/tuple of update ops, or zero-arg callable 1791 that returns an update op. A zero-arg callable should be passed in 1792 order to disable running the updates by setting `trainable=False` 1793 on this Layer, when executing in Eager mode. 1794 inputs: Deprecated, will be automatically inferred. 1795 """ 1796 if inputs is not None: 1797 tf_logging.warning( 1798 '`add_update` `inputs` kwarg has been deprecated. You no longer need ' 1799 'to pass a value to `inputs` as it is being automatically inferred.') 1800 call_context = base_layer_utils.call_context() 1801 # No need to run updates during Functional API construction. 1802 if call_context.in_keras_graph: 1803 return 1804 1805 # Callable updates are disabled by setting `trainable=False`. 1806 if not call_context.frozen: 1807 for update in nest.flatten(updates): 1808 if callable(update): 1809 update() # pylint: disable=not-callable 1810 1811 def set_weights(self, weights): 1812 """Sets the weights of the layer, from Numpy arrays. 1813 1814 The weights of a layer represent the state of the layer. This function 1815 sets the weight values from numpy arrays. The weight values should be 1816 passed in the order they are created by the layer. Note that the layer's 1817 weights must be instantiated before calling this function by calling 1818 the layer. 1819 1820 For example, a Dense layer returns a list of two values-- per-output 1821 weights and the bias value. These can be used to set the weights of another 1822 Dense layer: 1823 1824 >>> a = tf.keras.layers.Dense(1, 1825 ... kernel_initializer=tf.constant_initializer(1.)) 1826 >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]])) 1827 >>> a.get_weights() 1828 [array([[1.], 1829 [1.], 1830 [1.]], dtype=float32), array([0.], dtype=float32)] 1831 >>> b = tf.keras.layers.Dense(1, 1832 ... kernel_initializer=tf.constant_initializer(2.)) 1833 >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]])) 1834 >>> b.get_weights() 1835 [array([[2.], 1836 [2.], 1837 [2.]], dtype=float32), array([0.], dtype=float32)] 1838 >>> b.set_weights(a.get_weights()) 1839 >>> b.get_weights() 1840 [array([[1.], 1841 [1.], 1842 [1.]], dtype=float32), array([0.], dtype=float32)] 1843 1844 Args: 1845 weights: a list of Numpy arrays. The number 1846 of arrays and their shape must match 1847 number of the dimensions of the weights 1848 of the layer (i.e. it should match the 1849 output of `get_weights`). 1850 1851 Raises: 1852 ValueError: If the provided weights list does not match the 1853 layer's specifications. 1854 """ 1855 params = self.weights 1856 1857 expected_num_weights = 0 1858 for param in params: 1859 if isinstance(param, base_layer_utils.TrackableWeightHandler): 1860 expected_num_weights += param.num_tensors 1861 else: 1862 expected_num_weights += 1 1863 1864 if expected_num_weights != len(weights): 1865 raise ValueError( 1866 'You called `set_weights(weights)` on layer "%s" ' 1867 'with a weight list of length %s, but the layer was ' 1868 'expecting %s weights. Provided weights: %s...' % 1869 (self.name, len(weights), expected_num_weights, str(weights)[:50])) 1870 1871 weight_index = 0 1872 weight_value_tuples = [] 1873 for param in params: 1874 if isinstance(param, base_layer_utils.TrackableWeightHandler): 1875 num_tensors = param.num_tensors 1876 tensors = weights[weight_index:weight_index + num_tensors] 1877 param.set_weights(tensors) 1878 weight_index += num_tensors 1879 else: 1880 weight = weights[weight_index] 1881 ref_shape = param.shape 1882 if not ref_shape.is_compatible_with(weight.shape): 1883 raise ValueError( 1884 'Layer weight shape %s not compatible with provided weight ' 1885 'shape %s' % (ref_shape, weight.shape)) 1886 weight_value_tuples.append((param, weight)) 1887 weight_index += 1 1888 1889 backend.batch_set_value(weight_value_tuples) 1890 1891 def get_weights(self): 1892 """Returns the current weights of the layer. 1893 1894 The weights of a layer represent the state of the layer. This function 1895 returns both trainable and non-trainable weight values associated with this 1896 layer as a list of Numpy arrays, which can in turn be used to load state 1897 into similarly parameterized layers. 1898 1899 For example, a Dense layer returns a list of two values-- per-output 1900 weights and the bias value. These can be used to set the weights of another 1901 Dense layer: 1902 1903 >>> a = tf.keras.layers.Dense(1, 1904 ... kernel_initializer=tf.constant_initializer(1.)) 1905 >>> a_out = a(tf.convert_to_tensor([[1., 2., 3.]])) 1906 >>> a.get_weights() 1907 [array([[1.], 1908 [1.], 1909 [1.]], dtype=float32), array([0.], dtype=float32)] 1910 >>> b = tf.keras.layers.Dense(1, 1911 ... kernel_initializer=tf.constant_initializer(2.)) 1912 >>> b_out = b(tf.convert_to_tensor([[10., 20., 30.]])) 1913 >>> b.get_weights() 1914 [array([[2.], 1915 [2.], 1916 [2.]], dtype=float32), array([0.], dtype=float32)] 1917 >>> b.set_weights(a.get_weights()) 1918 >>> b.get_weights() 1919 [array([[1.], 1920 [1.], 1921 [1.]], dtype=float32), array([0.], dtype=float32)] 1922 1923 Returns: 1924 Weights values as a list of numpy arrays. 1925 """ 1926 weights = self.weights 1927 output_weights = [] 1928 for weight in weights: 1929 if isinstance(weight, base_layer_utils.TrackableWeightHandler): 1930 output_weights.extend(weight.get_tensors()) 1931 else: 1932 output_weights.append(weight) 1933 return backend.batch_get_value(output_weights) 1934 1935 @doc_controls.do_not_generate_docs 1936 def get_updates_for(self, inputs): 1937 """Deprecated, do NOT use! 1938 1939 Retrieves updates relevant to a specific set of inputs. 1940 1941 Args: 1942 inputs: Input tensor or list/tuple of input tensors. 1943 1944 Returns: 1945 List of update ops of the layer that depend on `inputs`. 1946 """ 1947 warnings.warn('`layer.get_updates_for` is deprecated and ' 1948 'will be removed in a future version. ' 1949 'Please use `layer.updates` method instead.') 1950 return self.updates 1951 1952 @doc_controls.do_not_generate_docs 1953 def get_losses_for(self, inputs): 1954 """Deprecated, do NOT use! 1955 1956 Retrieves losses relevant to a specific set of inputs. 1957 1958 Args: 1959 inputs: Input tensor or list/tuple of input tensors. 1960 1961 Returns: 1962 List of loss tensors of the layer that depend on `inputs`. 1963 """ 1964 warnings.warn('`layer.get_losses_for` is deprecated and ' 1965 'will be removed in a future version. ' 1966 'Please use `layer.losses` instead.') 1967 return self.losses 1968 1969 @doc_controls.do_not_doc_inheritable 1970 def get_input_mask_at(self, node_index): 1971 """Retrieves the input mask tensor(s) of a layer at a given node. 1972 1973 Args: 1974 node_index: Integer, index of the node 1975 from which to retrieve the attribute. 1976 E.g. `node_index=0` will correspond to the 1977 first time the layer was called. 1978 1979 Returns: 1980 A mask tensor 1981 (or list of tensors if the layer has multiple inputs). 1982 """ 1983 inputs = self.get_input_at(node_index) 1984 if isinstance(inputs, list): 1985 return [getattr(x, '_keras_mask', None) for x in inputs] 1986 else: 1987 return getattr(inputs, '_keras_mask', None) 1988 1989 @doc_controls.do_not_doc_inheritable 1990 def get_output_mask_at(self, node_index): 1991 """Retrieves the output mask tensor(s) of a layer at a given node. 1992 1993 Args: 1994 node_index: Integer, index of the node 1995 from which to retrieve the attribute. 1996 E.g. `node_index=0` will correspond to the 1997 first time the layer was called. 1998 1999 Returns: 2000 A mask tensor 2001 (or list of tensors if the layer has multiple outputs). 2002 """ 2003 output = self.get_output_at(node_index) 2004 if isinstance(output, list): 2005 return [getattr(x, '_keras_mask', None) for x in output] 2006 else: 2007 return getattr(output, '_keras_mask', None) 2008 2009 @property 2010 @doc_controls.do_not_doc_inheritable 2011 def input_mask(self): 2012 """Retrieves the input mask tensor(s) of a layer. 2013 2014 Only applicable if the layer has exactly one inbound node, 2015 i.e. if it is connected to one incoming layer. 2016 2017 Returns: 2018 Input mask tensor (potentially None) or list of input 2019 mask tensors. 2020 2021 Raises: 2022 AttributeError: if the layer is connected to 2023 more than one incoming layers. 2024 """ 2025 inputs = self.input 2026 if isinstance(inputs, list): 2027 return [getattr(x, '_keras_mask', None) for x in inputs] 2028 else: 2029 return getattr(inputs, '_keras_mask', None) 2030 2031 @property 2032 @doc_controls.do_not_doc_inheritable 2033 def output_mask(self): 2034 """Retrieves the output mask tensor(s) of a layer. 2035 2036 Only applicable if the layer has exactly one inbound node, 2037 i.e. if it is connected to one incoming layer. 2038 2039 Returns: 2040 Output mask tensor (potentially None) or list of output 2041 mask tensors. 2042 2043 Raises: 2044 AttributeError: if the layer is connected to 2045 more than one incoming layers. 2046 """ 2047 output = self.output 2048 if isinstance(output, list): 2049 return [getattr(x, '_keras_mask', None) for x in output] 2050 else: 2051 return getattr(output, '_keras_mask', None) 2052 2053 @doc_controls.do_not_doc_inheritable 2054 def get_input_shape_at(self, node_index): 2055 """Retrieves the input shape(s) of a layer at a given node. 2056 2057 Args: 2058 node_index: Integer, index of the node 2059 from which to retrieve the attribute. 2060 E.g. `node_index=0` will correspond to the 2061 first time the layer was called. 2062 2063 Returns: 2064 A shape tuple 2065 (or list of shape tuples if the layer has multiple inputs). 2066 2067 Raises: 2068 RuntimeError: If called in Eager mode. 2069 """ 2070 return self._get_node_attribute_at_index(node_index, 'input_shapes', 2071 'input shape') 2072 2073 @doc_controls.do_not_doc_inheritable 2074 def get_output_shape_at(self, node_index): 2075 """Retrieves the output shape(s) of a layer at a given node. 2076 2077 Args: 2078 node_index: Integer, index of the node 2079 from which to retrieve the attribute. 2080 E.g. `node_index=0` will correspond to the 2081 first time the layer was called. 2082 2083 Returns: 2084 A shape tuple 2085 (or list of shape tuples if the layer has multiple outputs). 2086 2087 Raises: 2088 RuntimeError: If called in Eager mode. 2089 """ 2090 return self._get_node_attribute_at_index(node_index, 'output_shapes', 2091 'output shape') 2092 2093 @doc_controls.do_not_doc_inheritable 2094 def get_input_at(self, node_index): 2095 """Retrieves the input tensor(s) of a layer at a given node. 2096 2097 Args: 2098 node_index: Integer, index of the node 2099 from which to retrieve the attribute. 2100 E.g. `node_index=0` will correspond to the 2101 first input node of the layer. 2102 2103 Returns: 2104 A tensor (or list of tensors if the layer has multiple inputs). 2105 2106 Raises: 2107 RuntimeError: If called in Eager mode. 2108 """ 2109 return self._get_node_attribute_at_index(node_index, 'input_tensors', 2110 'input') 2111 2112 @doc_controls.do_not_doc_inheritable 2113 def get_output_at(self, node_index): 2114 """Retrieves the output tensor(s) of a layer at a given node. 2115 2116 Args: 2117 node_index: Integer, index of the node 2118 from which to retrieve the attribute. 2119 E.g. `node_index=0` will correspond to the 2120 first output node of the layer. 2121 2122 Returns: 2123 A tensor (or list of tensors if the layer has multiple outputs). 2124 2125 Raises: 2126 RuntimeError: If called in Eager mode. 2127 """ 2128 return self._get_node_attribute_at_index(node_index, 'output_tensors', 2129 'output') 2130 2131 @property 2132 def input(self): 2133 """Retrieves the input tensor(s) of a layer. 2134 2135 Only applicable if the layer has exactly one input, 2136 i.e. if it is connected to one incoming layer. 2137 2138 Returns: 2139 Input tensor or list of input tensors. 2140 2141 Raises: 2142 RuntimeError: If called in Eager mode. 2143 AttributeError: If no inbound nodes are found. 2144 """ 2145 if not self._inbound_nodes: 2146 raise AttributeError('Layer ' + self.name + 2147 ' is not connected, no input to return.') 2148 return self._get_node_attribute_at_index(0, 'input_tensors', 'input') 2149 2150 @property 2151 def output(self): 2152 """Retrieves the output tensor(s) of a layer. 2153 2154 Only applicable if the layer has exactly one output, 2155 i.e. if it is connected to one incoming layer. 2156 2157 Returns: 2158 Output tensor or list of output tensors. 2159 2160 Raises: 2161 AttributeError: if the layer is connected to more than one incoming 2162 layers. 2163 RuntimeError: if called in Eager mode. 2164 """ 2165 if not self._inbound_nodes: 2166 raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') 2167 return self._get_node_attribute_at_index(0, 'output_tensors', 'output') 2168 2169 @property 2170 @doc_controls.do_not_doc_inheritable 2171 def input_shape(self): 2172 """Retrieves the input shape(s) of a layer. 2173 2174 Only applicable if the layer has exactly one input, 2175 i.e. if it is connected to one incoming layer, or if all inputs 2176 have the same shape. 2177 2178 Returns: 2179 Input shape, as an integer shape tuple 2180 (or list of shape tuples, one tuple per input tensor). 2181 2182 Raises: 2183 AttributeError: if the layer has no defined input_shape. 2184 RuntimeError: if called in Eager mode. 2185 """ 2186 if not self._inbound_nodes: 2187 raise AttributeError('The layer has never been called ' 2188 'and thus has no defined input shape.') 2189 all_input_shapes = set( 2190 [str(node.input_shapes) for node in self._inbound_nodes]) 2191 if len(all_input_shapes) == 1: 2192 return self._inbound_nodes[0].input_shapes 2193 else: 2194 raise AttributeError('The layer "' + str(self.name) + 2195 ' has multiple inbound nodes, ' 2196 'with different input shapes. Hence ' 2197 'the notion of "input shape" is ' 2198 'ill-defined for the layer. ' 2199 'Use `get_input_shape_at(node_index)` ' 2200 'instead.') 2201 2202 def count_params(self): 2203 """Count the total number of scalars composing the weights. 2204 2205 Returns: 2206 An integer count. 2207 2208 Raises: 2209 ValueError: if the layer isn't yet built 2210 (in which case its weights aren't yet defined). 2211 """ 2212 if not self.built: 2213 if getattr(self, '_is_graph_network', False): 2214 with tf_utils.maybe_init_scope(self): 2215 self._maybe_build(self.inputs) 2216 else: 2217 raise ValueError('You tried to call `count_params` on ' + self.name + 2218 ', but the layer isn\'t built. ' 2219 'You can build it manually via: `' + self.name + 2220 '.build(batch_input_shape)`.') 2221 return layer_utils.count_params(self.weights) 2222 2223 @property 2224 @doc_controls.do_not_doc_inheritable 2225 def output_shape(self): 2226 """Retrieves the output shape(s) of a layer. 2227 2228 Only applicable if the layer has one output, 2229 or if all outputs have the same shape. 2230 2231 Returns: 2232 Output shape, as an integer shape tuple 2233 (or list of shape tuples, one tuple per output tensor). 2234 2235 Raises: 2236 AttributeError: if the layer has no defined output shape. 2237 RuntimeError: if called in Eager mode. 2238 """ 2239 if not self._inbound_nodes: 2240 raise AttributeError('The layer has never been called ' 2241 'and thus has no defined output shape.') 2242 all_output_shapes = set( 2243 [str(node.output_shapes) for node in self._inbound_nodes]) 2244 if len(all_output_shapes) == 1: 2245 return self._inbound_nodes[0].output_shapes 2246 else: 2247 raise AttributeError('The layer "%s"' 2248 ' has multiple inbound nodes, ' 2249 'with different output shapes. Hence ' 2250 'the notion of "output shape" is ' 2251 'ill-defined for the layer. ' 2252 'Use `get_output_shape_at(node_index)` ' 2253 'instead.' % self.name) 2254 2255 @property 2256 @doc_controls.do_not_doc_inheritable 2257 def inbound_nodes(self): 2258 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 2259 return self._inbound_nodes 2260 2261 @property 2262 @doc_controls.do_not_doc_inheritable 2263 def outbound_nodes(self): 2264 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 2265 return self._outbound_nodes 2266 2267 ############################################################################## 2268 # Methods & attributes below are public aliases of other methods. # 2269 ############################################################################## 2270 2271 @doc_controls.do_not_doc_inheritable 2272 def apply(self, inputs, *args, **kwargs): 2273 """Deprecated, do NOT use! 2274 2275 This is an alias of `self.__call__`. 2276 2277 Args: 2278 inputs: Input tensor(s). 2279 *args: additional positional arguments to be passed to `self.call`. 2280 **kwargs: additional keyword arguments to be passed to `self.call`. 2281 2282 Returns: 2283 Output tensor(s). 2284 """ 2285 warnings.warn('`layer.apply` is deprecated and ' 2286 'will be removed in a future version. ' 2287 'Please use `layer.__call__` method instead.') 2288 return self.__call__(inputs, *args, **kwargs) 2289 2290 @doc_controls.do_not_doc_inheritable 2291 def add_variable(self, *args, **kwargs): 2292 """Deprecated, do NOT use! Alias for `add_weight`.""" 2293 warnings.warn('`layer.add_variable` is deprecated and ' 2294 'will be removed in a future version. ' 2295 'Please use `layer.add_weight` method instead.') 2296 return self.add_weight(*args, **kwargs) 2297 2298 @property 2299 @doc_controls.do_not_generate_docs 2300 def variables(self): 2301 """Returns the list of all layer variables/weights. 2302 2303 Alias of `self.weights`. 2304 2305 Note: This will not track the weights of nested `tf.Modules` that are not 2306 themselves Keras layers. 2307 2308 Returns: 2309 A list of variables. 2310 """ 2311 return self.weights 2312 2313 @property 2314 @doc_controls.do_not_generate_docs 2315 def trainable_variables(self): 2316 return self.trainable_weights 2317 2318 @property 2319 @doc_controls.do_not_generate_docs 2320 def non_trainable_variables(self): 2321 return self.non_trainable_weights 2322 2323 ############################################################################## 2324 # Methods & attributes below are all private and only used by the framework. # 2325 ############################################################################## 2326 2327 @property 2328 def _inbound_nodes(self): 2329 return self._inbound_nodes_value 2330 2331 @_inbound_nodes.setter 2332 @trackable.no_automatic_dependency_tracking 2333 def _inbound_nodes(self, value): 2334 self._inbound_nodes_value = value 2335 2336 @property 2337 def _outbound_nodes(self): 2338 return self._outbound_nodes_value 2339 2340 @_outbound_nodes.setter 2341 @trackable.no_automatic_dependency_tracking 2342 def _outbound_nodes(self, value): 2343 self._outbound_nodes_value = value 2344 2345 def _set_dtype_policy(self, dtype): 2346 """Sets self._dtype_policy.""" 2347 if isinstance(dtype, policy.Policy): 2348 self._dtype_policy = dtype 2349 elif isinstance(dtype, dict): 2350 self._dtype_policy = policy.deserialize(dtype) 2351 elif dtype: 2352 self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name) 2353 else: 2354 self._dtype_policy = policy.global_policy() 2355 if (self._dtype_policy.name == 'mixed_float16' and 2356 not loss_scale_optimizer.strategy_supports_loss_scaling()): 2357 # Although only loss scaling doesn't support certain strategies, to avoid 2358 # confusion, we disallow the 'mixed_float16' policy with unsupported 2359 # strategies. This is because 'mixed_float16' requires loss scaling for 2360 # numeric stability. 2361 strategy = ds_context.get_strategy() 2362 raise ValueError('Mixed precision is not supported with the ' 2363 'tf.distribute.Strategy: %s. Either stop using mixed ' 2364 'precision by removing the use of the "%s" policy or ' 2365 'use a different Strategy, e.g. a MirroredStrategy.' % 2366 (strategy.__class__.__name__, self._dtype_policy.name)) 2367 2368 # Performance optimization: cache the compute dtype as a Dtype object or 2369 # None, so that str to Dtype conversion doesn't happen in Layer.__call__. 2370 # TODO(b/157486353): Investigate returning DTypes in Policy. 2371 if self._dtype_policy.compute_dtype: 2372 self._compute_dtype_object = dtypes.as_dtype( 2373 self._dtype_policy.compute_dtype) 2374 else: 2375 self._compute_dtype_object = None 2376 2377 @property 2378 def dtype_policy(self): 2379 """The dtype policy associated with this layer. 2380 2381 This is an instance of a `tf.keras.mixed_precision.Policy`. 2382 """ 2383 return self._dtype_policy 2384 2385 @property 2386 def compute_dtype(self): 2387 """The dtype of the layer's computations. 2388 2389 This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless 2390 mixed precision is used, this is the same as `Layer.dtype`, the dtype of 2391 the weights. 2392 2393 Layers automatically cast their inputs to the compute dtype, which causes 2394 computations and the output to be in the compute dtype as well. This is done 2395 by the base Layer class in `Layer.__call__`, so you do not have to insert 2396 these casts if implementing your own layer. 2397 2398 Layers often perform certain internal computations in higher precision when 2399 `compute_dtype` is float16 or bfloat16 for numeric stability. The output 2400 will still typically be float16 or bfloat16 in such cases. 2401 2402 Returns: 2403 The layer's compute dtype. 2404 """ 2405 return self._dtype_policy.compute_dtype 2406 2407 @property 2408 def _compute_dtype(self): 2409 """Deprecated alias of `compute_dtype`.""" 2410 return self._dtype_policy.compute_dtype 2411 2412 @property 2413 def variable_dtype(self): 2414 """Alias of `Layer.dtype`, the dtype of the weights.""" 2415 return self.dtype 2416 2417 def _maybe_cast_inputs(self, inputs, input_list=None): 2418 """Maybe casts the inputs to the compute dtype. 2419 2420 If self._compute_dtype is floating-point, and self_autocast is True, 2421 floating-point inputs are casted to self._compute_dtype. 2422 2423 Args: 2424 inputs: Input tensor, or structure of input tensors. 2425 input_list: Flat list of input tensors. 2426 2427 Returns: 2428 `inputs`, but tensors may have been casted to self._compute_dtype 2429 """ 2430 if not input_list: 2431 input_list = nest.flatten(inputs) 2432 2433 compute_dtype_object = self._compute_dtype_object 2434 should_autocast = ( 2435 self._autocast and compute_dtype_object and 2436 compute_dtype_object.is_floating) 2437 2438 if (should_autocast and 2439 any(map(self._should_cast_single_input, input_list))): 2440 # Only perform expensive `nest` operation when needed. 2441 return nest.map_structure(self._cast_single_input, inputs) 2442 else: 2443 return inputs 2444 2445 def _should_cast_single_input(self, x): 2446 if isinstance(x, _AUTOCAST_TYPES): 2447 return (self._compute_dtype_object and 2448 x.dtype != self._compute_dtype_object and x.dtype.is_floating) 2449 return False 2450 2451 def _cast_single_input(self, x): 2452 """Cast a single Tensor or TensorSpec to the compute dtype.""" 2453 if self._should_cast_single_input(x): 2454 return math_ops.cast(x, self._compute_dtype_object) 2455 else: 2456 return x 2457 2458 # _dtype used to be an attribute set in the constructor. We still expose it 2459 # because some clients still use it. 2460 # TODO(reedwm): Deprecate, then remove the _dtype property. 2461 @property 2462 def _dtype(self): 2463 # This is equivalent to returning self.dtype . We do not return self.dtype 2464 # as it would cause infinite recursion in a few subclasses, which override 2465 # "dtype" to return self._dtype. 2466 return self._dtype_policy.variable_dtype 2467 2468 @_dtype.setter 2469 def _dtype(self, value): 2470 value = dtypes.as_dtype(value).name 2471 self._set_dtype_policy(policy.Policy(value)) 2472 2473 def _name_scope(self): 2474 if not tf2.enabled(): 2475 return self.name 2476 name_scope = self.name 2477 current_name_scope = ops.get_name_scope() 2478 if current_name_scope: 2479 name_scope = current_name_scope + '/' + name_scope 2480 if name_scope: 2481 # Note that the trailing `/` prevents autogenerated 2482 # numerical suffixes to get appended. It will also fully reset 2483 # nested name scope (i.e. the outer name scope has no effect). 2484 name_scope += '/' 2485 return name_scope 2486 2487 def _init_set_name(self, name, zero_based=True): 2488 if not name: 2489 self._name = backend.unique_object_name( 2490 generic_utils.to_snake_case(self.__class__.__name__), 2491 zero_based=zero_based) 2492 else: 2493 backend.observe_object_name(name) 2494 self._name = name 2495 2496 def _get_existing_metric(self, name=None): 2497 match = [m for m in self._metrics if m.name == name] 2498 if not match: 2499 return 2500 if len(match) > 1: 2501 raise ValueError( 2502 'Please provide different names for the metrics you have added. ' 2503 'We found {} metrics with the name: "{}"'.format(len(match), name)) 2504 return match[0] 2505 2506 def _handle_weight_regularization(self, name, variable, regularizer): 2507 """Create lambdas which compute regularization losses.""" 2508 2509 def _loss_for_variable(v): 2510 """Creates a regularization loss `Tensor` for variable `v`.""" 2511 with backend.name_scope(name + '/Regularizer'): 2512 regularization = regularizer(v) 2513 return regularization 2514 2515 if base_layer_utils.is_split_variable(variable): 2516 for v in variable: 2517 self.add_loss(functools.partial(_loss_for_variable, v)) 2518 else: 2519 self.add_loss(functools.partial(_loss_for_variable, variable)) 2520 2521 def _handle_activity_regularization(self, inputs, outputs): 2522 # Apply activity regularization. 2523 # Note that it should be applied every time the layer creates a new 2524 # output, since it is output-specific. 2525 if self._activity_regularizer: 2526 output_list = nest.flatten(outputs) 2527 with backend.name_scope('ActivityRegularizer'): 2528 for output in output_list: 2529 activity_loss = self._activity_regularizer(output) 2530 batch_size = math_ops.cast( 2531 array_ops.shape(output)[0], activity_loss.dtype) 2532 # Make activity regularization strength batch-agnostic. 2533 mean_activity_loss = activity_loss / batch_size 2534 self.add_loss(mean_activity_loss) 2535 2536 def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph): 2537 # Many `Layer`s don't need to call `compute_mask`. 2538 # This method is optimized to do as little work as needed for the common 2539 # case. 2540 if not self._supports_masking: 2541 return 2542 2543 flat_outputs = nest.flatten(outputs) 2544 2545 mask_already_computed = ( 2546 getattr(self, '_compute_output_and_mask_jointly', False) or 2547 all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) 2548 if mask_already_computed: 2549 if build_graph: 2550 self._set_mask_keras_history_checked(flat_outputs) 2551 return 2552 2553 output_masks = self.compute_mask(inputs, previous_mask) 2554 if output_masks is None: 2555 return 2556 2557 flat_masks = nest.flatten(output_masks) 2558 for tensor, mask in zip(flat_outputs, flat_masks): 2559 try: 2560 tensor._keras_mask = mask 2561 except AttributeError: 2562 # C Type such as np.ndarray. 2563 pass 2564 2565 if build_graph: 2566 self._set_mask_keras_history_checked(flat_outputs) 2567 2568 def _set_mask_keras_history_checked(self, flat_outputs): 2569 for output in flat_outputs: 2570 if getattr(output, '_keras_mask', None) is not None: 2571 # Do not track masks for `TensorFlowOpLayer` construction. 2572 output._keras_mask._keras_history_checked = True 2573 2574 def _get_input_masks(self, inputs, input_list, args, kwargs): 2575 if not self._supports_masking and not self._expects_mask_arg: 2576 # Input masks only need to be retrieved if they are needed for `call` 2577 # or `compute_mask`. 2578 input_masks = None 2579 implicit_mask = False 2580 elif self._call_arg_was_passed('mask', args, kwargs): 2581 input_masks = self._get_call_arg_value('mask', args, kwargs) 2582 implicit_mask = False 2583 else: 2584 input_masks = [getattr(t, '_keras_mask', None) for t in input_list] 2585 if all(mask is None for mask in input_masks): 2586 input_masks = None 2587 implicit_mask = False 2588 else: 2589 # Only do expensive `nest` op when masking is actually being used. 2590 input_masks = nest.pack_sequence_as(inputs, input_masks) 2591 implicit_mask = True 2592 return input_masks, implicit_mask 2593 2594 def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): 2595 # Performance optimization: do no work in most common case. 2596 if not args and not kwargs: 2597 return False 2598 2599 if arg_name in kwargs: 2600 return True 2601 call_fn_args = self._call_fn_args 2602 if not inputs_in_args: 2603 # Ignore `inputs` arg. 2604 call_fn_args = call_fn_args[1:] 2605 return arg_name in dict(zip(call_fn_args, args)) 2606 2607 def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False): 2608 if arg_name in kwargs: 2609 return kwargs[arg_name] 2610 call_fn_args = self._call_fn_args 2611 if not inputs_in_args: 2612 # Ignore `inputs` arg. 2613 call_fn_args = call_fn_args[1:] 2614 args_dict = dict(zip(call_fn_args, args)) 2615 return args_dict[arg_name] 2616 2617 def _set_call_arg_value( 2618 self, arg_name, new_value, args, 2619 kwargs, inputs_in_args=False, pop_kwarg_if_none=False): 2620 arg_pos = self._call_fn_arg_positions.get(arg_name, None) 2621 if arg_pos is not None: 2622 if not inputs_in_args: 2623 # Ignore `inputs` arg. 2624 arg_pos = arg_pos - 1 2625 if len(args) > arg_pos: 2626 args = list(args) 2627 args[arg_pos] = new_value 2628 return tuple(args), kwargs 2629 if new_value is None and pop_kwarg_if_none: 2630 kwargs.pop(arg_name, None) 2631 else: 2632 kwargs[arg_name] = new_value 2633 return args, kwargs 2634 2635 def _set_connectivity_metadata(self, args, kwargs, outputs): 2636 # If the layer returns tensors from its inputs unmodified, 2637 # we copy them to avoid loss of KerasHistory metadata. 2638 flat_outputs = nest.flatten(outputs) 2639 flat_inputs = nest.flatten((args, kwargs)) 2640 input_ids_set = {id(i) for i in flat_inputs} 2641 outputs_copy = [] 2642 for x in flat_outputs: 2643 if id(x) in input_ids_set: 2644 with backend.name_scope(self.name): 2645 x = array_ops.identity(x) 2646 outputs_copy.append(x) 2647 outputs = nest.pack_sequence_as(outputs, outputs_copy) 2648 2649 # Create node, Node wires itself to inbound and outbound layers. 2650 # The Node constructor actually updates this layer's self._inbound_nodes, 2651 # sets _keras_history on the outputs, and adds itself to the 2652 # `_outbound_nodes` of the layers that produced the inputs to this 2653 # layer call. 2654 node_module.Node(self, call_args=args, call_kwargs=kwargs, outputs=outputs) 2655 return outputs 2656 2657 def _get_node_attribute_at_index(self, node_index, attr, attr_name): 2658 """Private utility to retrieves an attribute (e.g. inputs) from a node. 2659 2660 This is used to implement the methods: 2661 - get_input_shape_at 2662 - get_output_shape_at 2663 - get_input_at 2664 etc... 2665 2666 Args: 2667 node_index: Integer index of the node from which 2668 to retrieve the attribute. 2669 attr: Exact node attribute name. 2670 attr_name: Human-readable attribute name, for error messages. 2671 2672 Returns: 2673 The layer's attribute `attr` at the node of index `node_index`. 2674 2675 Raises: 2676 RuntimeError: If the layer has no inbound nodes, or if called in Eager 2677 mode. 2678 ValueError: If the index provided does not match any node. 2679 """ 2680 if not self._inbound_nodes: 2681 raise RuntimeError('The layer has never been called ' 2682 'and thus has no defined ' + attr_name + '.') 2683 if not len(self._inbound_nodes) > node_index: 2684 raise ValueError('Asked to get ' + attr_name + ' at node ' + 2685 str(node_index) + ', but the layer has only ' + 2686 str(len(self._inbound_nodes)) + ' inbound nodes.') 2687 values = getattr(self._inbound_nodes[node_index], attr) 2688 if isinstance(values, list) and len(values) == 1: 2689 return values[0] 2690 else: 2691 return values 2692 2693 def _maybe_build(self, inputs): 2694 # Check input assumptions set before layer building, e.g. input rank. 2695 if not self.built: 2696 input_spec.assert_input_compatibility( 2697 self.input_spec, inputs, self.name) 2698 input_list = nest.flatten(inputs) 2699 if input_list and self._dtype_policy.compute_dtype is None: 2700 try: 2701 dtype = input_list[0].dtype.base_dtype.name 2702 except AttributeError: 2703 pass 2704 else: 2705 self._set_dtype_policy(policy.Policy(dtype)) 2706 input_shapes = None 2707 # Converts Tensors / CompositeTensors to TensorShapes. 2708 if all(hasattr(x, 'shape') for x in input_list): 2709 input_shapes = tf_utils.get_shapes(inputs) 2710 else: 2711 # Converts input shape to TensorShapes. 2712 try: 2713 input_shapes = tf_utils.convert_shapes(inputs, to_tuples=False) 2714 except ValueError: 2715 pass 2716 # Only call `build` if the user has manually overridden the build method. 2717 if not hasattr(self.build, '_is_default'): 2718 # Any setup work performed only once should happen in an `init_scope` 2719 # to avoid creating symbolic Tensors that will later pollute any eager 2720 # operations. 2721 with tf_utils.maybe_init_scope(self): 2722 self.build(input_shapes) # pylint:disable=not-callable 2723 # We must set also ensure that the layer is marked as built, and the build 2724 # shape is stored since user defined build functions may not be calling 2725 # `super.build()` 2726 Layer.build(self, input_shapes) 2727 2728 # Optionally load weight values specified at layer instantiation. 2729 if self._initial_weights is not None: 2730 if ops.executing_eagerly_outside_functions(): 2731 with ops.init_scope(): 2732 # Using `init_scope` since we want variable assignment in 2733 # `set_weights` to be treated like variable initialization. 2734 self.set_weights(self._initial_weights) 2735 else: 2736 self.set_weights(self._initial_weights) 2737 self._initial_weights = None 2738 2739 def _symbolic_call(self, inputs): 2740 input_shapes = nest.map_structure(lambda x: x.shape, inputs) 2741 output_shapes = self.compute_output_shape(input_shapes) 2742 # Convert to TensorShape so that nest.map_structure will not map into 2743 # individual dim of the shape. 2744 output_shapes = tf_utils.convert_shapes(output_shapes, to_tuples=False) 2745 2746 def _make_placeholder_like(shape): 2747 ph = backend.placeholder(shape=shape, dtype=self.dtype) 2748 ph._keras_mask = None 2749 return ph 2750 return nest.map_structure(_make_placeholder_like, output_shapes) 2751 2752 def _get_trainable_state(self): 2753 """Get the `trainable` state of each sublayer. 2754 2755 Returns: 2756 A dict mapping all sublayers to their `trainable` value. 2757 """ 2758 trainable_state = weakref.WeakKeyDictionary() 2759 for layer in self._flatten_layers(): 2760 trainable_state[layer] = layer.trainable 2761 return trainable_state 2762 2763 def _set_trainable_state(self, trainable_state): 2764 """Set `trainable` state for each sublayer.""" 2765 for layer in self._flatten_layers(): 2766 if layer in trainable_state: 2767 layer.trainable = trainable_state[layer] 2768 2769 @property 2770 def _obj_reference_counts(self): 2771 """A dictionary counting the number of attributes referencing an object.""" 2772 self._maybe_create_attribute('_obj_reference_counts_dict', 2773 object_identity.ObjectIdentityDictionary()) 2774 return self._obj_reference_counts_dict 2775 2776 @trackable.no_automatic_dependency_tracking 2777 def _maybe_create_attribute(self, name, default_value): 2778 """Create the attribute with the default value if it hasn't been created. 2779 2780 This is useful for fields that is used for tracking purpose, 2781 _trainable_weights, or _layers. Note that user could create a layer subclass 2782 and assign an internal field before invoking the Layer.__init__(), the 2783 __setattr__() need to create the tracking fields and __init__() need to not 2784 override them. 2785 2786 Args: 2787 name: String, the name of the attribute. 2788 default_value: Object, the default value of the attribute. 2789 """ 2790 if not hasattr(self, name): 2791 self.__setattr__(name, default_value) 2792 2793 def __delattr__(self, name): 2794 # For any super.__delattr__() call, we will directly use the implementation 2795 # in Trackable and skip the behavior in AutoTrackable. The Layer was 2796 # originally use Trackable as base class, the change of using Module as base 2797 # class forced us to have AutoTrackable in the class hierarchy. Skipping 2798 # the __delattr__ and __setattr__ in AutoTrackable will keep the status quo. 2799 existing_value = getattr(self, name, None) 2800 2801 # If this value is replacing an existing object assigned to an attribute, we 2802 # should clean it out to avoid leaking memory. First we check if there are 2803 # other attributes referencing it. 2804 reference_counts = self._obj_reference_counts 2805 if existing_value not in reference_counts: 2806 super(tracking.AutoTrackable, self).__delattr__(name) 2807 return 2808 2809 reference_count = reference_counts[existing_value] 2810 if reference_count > 1: 2811 # There are other remaining references. We can't remove this object from 2812 # _layers etc. 2813 reference_counts[existing_value] = reference_count - 1 2814 super(tracking.AutoTrackable, self).__delattr__(name) 2815 return 2816 else: 2817 # This is the last remaining reference. 2818 del reference_counts[existing_value] 2819 2820 super(tracking.AutoTrackable, self).__delattr__(name) 2821 2822 if (isinstance(existing_value, Layer) 2823 or base_layer_utils.has_weights(existing_value)): 2824 super(tracking.AutoTrackable, self).__setattr__( 2825 '_self_tracked_trackables', 2826 [l for l in self._self_tracked_trackables if l is not existing_value]) 2827 if isinstance(existing_value, tf_variables.Variable): 2828 super(tracking.AutoTrackable, self).__setattr__( 2829 '_trainable_weights', 2830 [w for w in self._trainable_weights if w is not existing_value]) 2831 super(tracking.AutoTrackable, self).__setattr__( 2832 '_non_trainable_weights', 2833 [w for w in self._non_trainable_weights if w is not existing_value]) 2834 2835 def __setattr__(self, name, value): 2836 if (name == '_self_setattr_tracking' or 2837 not getattr(self, '_self_setattr_tracking', True) or 2838 # Exclude @property.setters from tracking 2839 hasattr(self.__class__, name)): 2840 try: 2841 super(tracking.AutoTrackable, self).__setattr__(name, value) 2842 except AttributeError: 2843 raise AttributeError( 2844 ('Can\'t set the attribute "{}", likely because it conflicts with ' 2845 'an existing read-only @property of the object. Please choose a ' 2846 'different name.').format(name)) 2847 return 2848 2849 # Wraps data structures in `Trackable`, unwraps `NoDependency` objects. 2850 value = data_structures.sticky_attribute_assignment( 2851 trackable=self, value=value, name=name) 2852 2853 reference_counts = self._obj_reference_counts 2854 reference_counts[value] = reference_counts.get(value, 0) + 1 2855 2856 # Clean out the old attribute, which clears _layers and _trainable_weights 2857 # if necessary. 2858 try: 2859 self.__delattr__(name) 2860 except AttributeError: 2861 pass 2862 2863 # Keep track of metric instance created in subclassed layer. 2864 for val in nest.flatten(value): 2865 if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'): 2866 self._metrics.append(val) 2867 2868 # Append value to self._self_tracked_trackables if relevant 2869 if (getattr(self, '_auto_track_sub_layers', True) and 2870 (isinstance(value, module.Module) or 2871 base_layer_utils.has_weights(value))): 2872 self._maybe_create_attribute('_self_tracked_trackables', []) 2873 # We need to check object identity to avoid de-duplicating empty 2874 # container types which compare equal. 2875 if not any((layer is value for layer in self._self_tracked_trackables)): 2876 self._self_tracked_trackables.append(value) 2877 if hasattr(value, '_use_resource_variables'): 2878 # Legacy layers (V1 tf.layers) must always use 2879 # resource variables. 2880 value._use_resource_variables = True 2881 2882 # Append value to list of trainable / non-trainable weights if relevant 2883 # TODO(b/125122625): This won't pick up on any variables added to a 2884 # list/dict after creation. 2885 for val in nest.flatten(value, expand_composites=True): 2886 if not isinstance(val, tf_variables.Variable): 2887 continue 2888 2889 # Users may add extra weights/variables 2890 # simply by assigning them to attributes (invalid for graph networks) 2891 self._maybe_create_attribute('_trainable_weights', []) 2892 self._maybe_create_attribute('_non_trainable_weights', []) 2893 if val.trainable: 2894 if any(val is w for w in self._trainable_weights): 2895 continue 2896 self._trainable_weights.append(val) 2897 else: 2898 if any(val is w for w in self._non_trainable_weights): 2899 continue 2900 self._non_trainable_weights.append(val) 2901 2902 backend.track_variable(val) 2903 2904 # Skip the auto trackable from tf.Module to keep status quo. See the comment 2905 # at __delattr__. 2906 super(tracking.AutoTrackable, self).__setattr__(name, value) 2907 2908 def _gather_children_attribute(self, attribute): 2909 assert attribute in { 2910 'variables', 'trainable_variables', 'non_trainable_variables' 2911 } 2912 if hasattr(self, '_self_tracked_trackables'): 2913 nested_layers = self._flatten_modules(include_self=False, recursive=False) 2914 return list( 2915 itertools.chain.from_iterable( 2916 getattr(layer, attribute) for layer in nested_layers)) 2917 return [] 2918 2919 def _flatten_layers(self, recursive=True, include_self=True): 2920 for m in self._flatten_modules( 2921 recursive=recursive, include_self=include_self): 2922 if isinstance(m, Layer): 2923 yield m 2924 2925 def _flatten_modules(self, recursive=True, include_self=True): 2926 """Flattens `tf.Module` instances (excluding `Metrics`). 2927 2928 Args: 2929 recursive: Whether to recursively flatten through submodules. 2930 include_self: Whether to include this `Layer` instance. 2931 2932 Yields: 2933 `tf.Module` instance tracked by this `Layer`. 2934 """ 2935 if include_self: 2936 yield self 2937 2938 # Only instantiate set and deque if needed. 2939 trackables = getattr(self, '_self_tracked_trackables', None) 2940 if trackables: 2941 seen_object_ids = set() 2942 deque = collections.deque(trackables) 2943 while deque: 2944 trackable_obj = deque.popleft() 2945 trackable_id = id(trackable_obj) 2946 if trackable_id in seen_object_ids: 2947 continue 2948 seen_object_ids.add(trackable_id) 2949 2950 # Metrics are not considered part of the Layer's topology. 2951 if (isinstance(trackable_obj, module.Module) and 2952 not isinstance(trackable_obj, metrics_mod.Metric)): 2953 yield trackable_obj 2954 # Introspect recursively through sublayers. 2955 if recursive: 2956 subtrackables = getattr(trackable_obj, '_self_tracked_trackables', 2957 None) 2958 if subtrackables: 2959 deque.extendleft(reversed(subtrackables)) 2960 elif isinstance(trackable_obj, data_structures.TrackableDataStructure): 2961 # Data structures are introspected even with `recursive=False`. 2962 tracked_values = trackable_obj._values 2963 if tracked_values: 2964 deque.extendleft(reversed(tracked_values)) 2965 2966 # This is a hack so that the is_layer (within 2967 # training/trackable/layer_utils.py) check doesn't get the weights attr. 2968 # TODO(b/110718070): Remove when fixed. 2969 def _is_layer(self): 2970 return True 2971 2972 def _init_call_fn_args(self): 2973 # Clear cached call function arguments. 2974 self.__class__._call_full_argspec.fget.cache.pop(self, None) 2975 self.__class__._call_fn_args.fget.cache.pop(self, None) 2976 self.__class__._call_accepts_kwargs.fget.cache.pop(self, None) 2977 2978 call_fn_args = self._call_fn_args 2979 self._expects_training_arg = ('training' in call_fn_args or 2980 self._call_accepts_kwargs) 2981 # The default training arg will be any (non-None) default specified in the 2982 # method signature, or None if no value is specified. 2983 self._default_training_arg = self._call_fn_arg_defaults.get( 2984 'training') 2985 self._expects_mask_arg = ('mask' in call_fn_args or 2986 self._call_accepts_kwargs) 2987 2988 @property 2989 @layer_utils.cached_per_instance 2990 def _call_full_argspec(self): 2991 # Argspec inspection is expensive and the call spec is used often, so it 2992 # makes sense to cache the result. 2993 return tf_inspect.getfullargspec(self.call) 2994 2995 @property 2996 @layer_utils.cached_per_instance 2997 def _call_fn_args(self): 2998 all_args = self._call_full_argspec.args 2999 # Scrub `self` that appears if a decorator was applied. 3000 if all_args and all_args[0] == 'self': 3001 return all_args[1:] 3002 return all_args 3003 3004 @property 3005 @layer_utils.cached_per_instance 3006 def _call_fn_arg_defaults(self): 3007 call_fn_args = self._call_fn_args 3008 call_fn_defaults = self._call_full_argspec.defaults or [] 3009 defaults = dict() 3010 3011 # The call arg defaults are an n-tuple of the last n elements of the args 3012 # list. (n = # of elements that have a default argument) 3013 for i in range(-1 * len(call_fn_defaults), 0): 3014 defaults[call_fn_args[i]] = call_fn_defaults[i] 3015 return defaults 3016 3017 @property 3018 @layer_utils.cached_per_instance 3019 def _call_fn_arg_positions(self): 3020 call_fn_arg_positions = dict() 3021 for pos, arg in enumerate(self._call_fn_args): 3022 call_fn_arg_positions[arg] = pos 3023 return call_fn_arg_positions 3024 3025 @property 3026 @layer_utils.cached_per_instance 3027 def _call_accepts_kwargs(self): 3028 return self._call_full_argspec.varkw is not None 3029 3030 @property 3031 def _eager_losses(self): 3032 # A list of loss values containing activity regularizers and losses 3033 # manually added through `add_loss` during eager execution. It is cleared 3034 # after every batch. 3035 # Because we plan on eventually allowing a same model instance to be trained 3036 # in eager mode or graph mode alternatively, we need to keep track of 3037 # eager losses and symbolic losses via separate attributes. 3038 if not hasattr(self._thread_local, '_eager_losses'): 3039 self._thread_local._eager_losses = [] 3040 return self._thread_local._eager_losses 3041 3042 @_eager_losses.setter 3043 def _eager_losses(self, losses): 3044 self._thread_local._eager_losses = losses 3045 3046 def _dedup_weights(self, weights): 3047 """Dedupe weights while maintaining order as much as possible.""" 3048 output, seen_ids = [], set() 3049 for w in weights: 3050 if id(w) not in seen_ids: 3051 output.append(w) 3052 # Track the Variable's identity to avoid __eq__ issues. 3053 seen_ids.add(id(w)) 3054 3055 return output 3056 3057 def _split_out_first_arg(self, args, kwargs): 3058 # Grab the argument corresponding to the first argument in the 3059 # layer's `call` method spec. This will either be the first positional 3060 # argument, or it will be provided as a keyword argument. 3061 if args: 3062 inputs = args[0] 3063 args = args[1:] 3064 elif self._call_fn_args[0] in kwargs: 3065 kwargs = copy.copy(kwargs) 3066 inputs = kwargs.pop(self._call_fn_args[0]) 3067 else: 3068 raise ValueError( 3069 'The first argument to `Layer.call` must always be passed.') 3070 return inputs, args, kwargs 3071 3072 # SavedModel properties. Please see keras/saving/saved_model for details. 3073 3074 @trackable.no_automatic_dependency_tracking 3075 def _set_save_spec(self, inputs): 3076 if self._saved_model_inputs_spec is not None: 3077 return # Already set. 3078 3079 self._saved_model_inputs_spec = nest.map_structure(tf_utils.get_tensor_spec, 3080 inputs) 3081 3082 def _get_save_spec(self, dynamic_batch=True): 3083 if self._saved_model_inputs_spec is None: 3084 return None 3085 3086 return nest.map_structure( 3087 lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch), 3088 self._saved_model_inputs_spec) 3089 3090 @property 3091 def _trackable_saved_model_saver(self): 3092 return layer_serialization.LayerSavedModelSaver(self) 3093 3094 @property 3095 def _object_identifier(self): 3096 return self._trackable_saved_model_saver.object_identifier 3097 3098 @property 3099 def _tracking_metadata(self): 3100 return self._trackable_saved_model_saver.tracking_metadata 3101 3102 def _list_extra_dependencies_for_serialization(self, serialization_cache): 3103 return (self._trackable_saved_model_saver 3104 .list_extra_dependencies_for_serialization(serialization_cache)) 3105 3106 def _list_functions_for_serialization(self, serialization_cache): 3107 return (self._trackable_saved_model_saver 3108 .list_functions_for_serialization(serialization_cache)) 3109 3110 @property 3111 def _use_input_spec_as_call_signature(self): 3112 # Whether input spec can be used as the call signature when tracing the 3113 # Layer for SavedModel. By default, this is set to `True` for layers 3114 # exported from the Keras library, because the layers more rigidly define 3115 # the `input_specs` property (many custom layers only set the `ndims`) 3116 return get_canonical_name_for_symbol(type(self)) is not None 3117 3118 def __getstate__(self): 3119 # Override to support `copy.deepcopy` and pickling. 3120 # Thread-local objects cannot be copied in Python 3, so pop these. 3121 # Thread-local objects are used to cache losses in MirroredStrategy, and 3122 # so shouldn't be copied. 3123 state = self.__dict__.copy() 3124 state.pop('_thread_local', None) 3125 state.pop('_metrics_lock', None) 3126 return state 3127 3128 def __setstate__(self, state): 3129 state['_thread_local'] = threading.local() 3130 state['_metrics_lock'] = threading.Lock() 3131 # Bypass Trackable logic as `__dict__` already contains this info. 3132 object.__setattr__(self, '__dict__', state) 3133 3134 3135class TensorFlowOpLayer(Layer): 3136 """Wraps a TensorFlow Operation in a Layer. 3137 3138 This class is used internally by the Functional API. When a user 3139 uses a raw TensorFlow Operation on symbolic tensors originating 3140 from an `Input` Layer, the resultant operation will be wrapped 3141 with this Layer object in order to make the operation compatible 3142 with the Keras API. 3143 3144 This Layer will create a new, identical operation (except for inputs 3145 and outputs) every time it is called. If `run_eagerly` is `True`, 3146 the op creation and calculation will happen inside an Eager function. 3147 3148 Instances of this Layer are created when `autolambda` is called, which 3149 is whenever a Layer's `__call__` encounters symbolic inputs that do 3150 not have Keras metadata, or when a Network's `__init__` encounters 3151 outputs that do not have Keras metadata. 3152 3153 Attributes: 3154 node_def: String, the serialized NodeDef of the Op this layer will wrap. 3155 name: String, the name of the Layer. 3156 constants: Dict of NumPy arrays, the values of any Tensors needed for this 3157 Operation that do not originate from a Keras `Input` Layer. Since all 3158 placeholders must come from Keras `Input` Layers, these Tensors must be 3159 treated as constant in the Functional API. 3160 trainable: Bool, whether this Layer is trainable. Currently Variables are 3161 not supported, and so this parameter has no effect. 3162 dtype: The default dtype of this Layer. Inherited from `Layer` and has no 3163 effect on this class, however is used in `get_config`. 3164 """ 3165 3166 @trackable.no_automatic_dependency_tracking 3167 def __init__(self, 3168 node_def, 3169 name, 3170 constants=None, 3171 trainable=True, 3172 dtype=None): 3173 # Pass autocast=False, as if inputs are cast, input types might not match 3174 # Operation type. 3175 super(TensorFlowOpLayer, self).__init__( 3176 name=_TF_OP_LAYER_NAME_PREFIX + name, trainable=trainable, dtype=dtype, 3177 autocast=False) 3178 if isinstance(node_def, dict): 3179 self.node_def = json_format.ParseDict(node_def, node_def_pb2.NodeDef()) 3180 else: 3181 if not isinstance(node_def, bytes): 3182 node_def = node_def.encode('utf-8') 3183 self.node_def = node_def_pb2.NodeDef.FromString(node_def) 3184 # JSON serialization stringifies keys which are integer input indices. 3185 self.constants = ({ 3186 int(index): constant for index, constant in constants.items() 3187 } if constants is not None else {}) 3188 # Layer uses original op unless it is called on new inputs. 3189 # This means `built` is not set in `__call__`. 3190 self.built = True 3191 3192 # Do not individually trace TensorflowOpLayers in the SavedModel. 3193 self._must_restore_from_config = True 3194 3195 def call(self, inputs): 3196 if context.executing_eagerly(): 3197 return self._defun_call(inputs) 3198 return self._make_op(inputs) 3199 3200 def _make_node_def(self, graph): 3201 node_def = node_def_pb2.NodeDef() 3202 node_def.CopyFrom(self.node_def) 3203 # Used in TPUReplicateContext to indicate whether this node has been cloned 3204 # and to not add TPU attributes. 3205 node_def.attr['_cloned'].b = True 3206 node_def.name = graph.unique_name(node_def.name) 3207 return node_def 3208 3209 def _make_op(self, inputs): 3210 inputs = nest.flatten(inputs) 3211 graph = inputs[0].graph 3212 node_def = self._make_node_def(graph) 3213 with graph.as_default(): 3214 for index, constant in self.constants.items(): 3215 # Recreate constant in graph to add distribution context. 3216 value = tensor_util.constant_value(constant) 3217 if value is not None: 3218 constant = constant_op.constant(value, name=node_def.input[index]) 3219 inputs.insert(index, constant) 3220 c_op = ops._create_c_op(graph, node_def, inputs, control_inputs=[]) 3221 op = graph._create_op_from_tf_operation(c_op) 3222 op._control_flow_post_processing() 3223 3224 # Record the gradient because custom-made ops don't go through the 3225 # code-gen'd eager call path 3226 op_type = compat.as_str(op.op_def.name) 3227 attr_names = [compat.as_str(attr.name) for attr in op.op_def.attr] 3228 attrs = [] 3229 for attr_name in attr_names: 3230 attrs.append(attr_name) 3231 attrs.append(op.get_attr(attr_name)) 3232 attrs = tuple(attrs) 3233 execute.record_gradient(op_type, op.inputs, attrs, op.outputs) 3234 3235 if len(op.outputs) == 1: 3236 return op.outputs[0] 3237 return op.outputs 3238 3239 @def_function.function 3240 def _defun_call(self, inputs): 3241 """Wraps the op creation method in an Eager function for `run_eagerly`.""" 3242 return self._make_op(inputs) 3243 3244 def get_config(self): 3245 config = super(TensorFlowOpLayer, self).get_config() 3246 config.update({ 3247 # `__init__` prefixes the name. Revert to the constructor argument. 3248 'name': config['name'][len(_TF_OP_LAYER_NAME_PREFIX):], 3249 'node_def': json_format.MessageToDict(self.node_def), 3250 'constants': { 3251 i: backend.get_value(c) for i, c in self.constants.items() 3252 } 3253 }) 3254 return config 3255 3256 3257class AddLoss(Layer): 3258 """Adds its inputs as a loss. 3259 3260 Attributes: 3261 unconditional: Whether or not the loss should be conditioned on the inputs. 3262 """ 3263 3264 def __init__(self, unconditional, **kwargs): 3265 # Pass autocast=False, as there is no reason to cast loss to a different 3266 # dtype. 3267 kwargs['autocast'] = False 3268 super(AddLoss, self).__init__(**kwargs) 3269 self.unconditional = unconditional 3270 3271 def call(self, inputs): 3272 self.add_loss(inputs, inputs=(not self.unconditional)) 3273 return inputs 3274 3275 def get_config(self): 3276 config = super(AddLoss, self).get_config() 3277 config.update({'unconditional': self.unconditional}) 3278 return config 3279 3280 3281class AddMetric(Layer): 3282 """Adds its inputs as a metric. 3283 3284 Attributes: 3285 aggregation: 'mean' or None. How the inputs should be aggregated. 3286 metric_name: The name to use for this metric. 3287 """ 3288 3289 def __init__(self, aggregation=None, metric_name=None, **kwargs): 3290 super(AddMetric, self).__init__(**kwargs) 3291 self.aggregation = aggregation 3292 self.metric_name = metric_name 3293 3294 def call(self, inputs): 3295 self.add_metric(inputs, aggregation=self.aggregation, name=self.metric_name) 3296 return inputs 3297 3298 def get_config(self): 3299 config = super(AddMetric, self).get_config() 3300 config.update({ 3301 'aggregation': self.aggregation, 3302 'metric_name': self.metric_name 3303 }) 3304 return config 3305 3306 3307def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument 3308 """Check the arguments to see if we are constructing a functional model.""" 3309 if keras_tensor.keras_tensors_enabled(): 3310 # We are constructing a functional model if any of the inputs 3311 # are KerasTensors 3312 return any( 3313 isinstance(tensor, keras_tensor.KerasTensor) 3314 for tensor in nest.flatten([inputs, args, kwargs])) 3315 else: 3316 if context.executing_eagerly(): 3317 all_inputs_symbolic = all( 3318 tf_utils.is_symbolic_tensor(t) for t in input_list) 3319 if (base_layer_utils.is_subclassed(layer) and 3320 any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten( 3321 [inputs, args, kwargs])) and not all_inputs_symbolic): 3322 raise ValueError('It appears you are trying to construct a ' 3323 'functional model, but not all of the inputs in ' 3324 'the first positional argument of your layer call ' 3325 'are symbolic tensors. ' 3326 '(Input objects, or the output of another layer) ' 3327 'Functional models cannot correctly track custom ' 3328 'layers unless all values in the first call argument ' 3329 'are symbolic.') 3330 return all_inputs_symbolic 3331 else: 3332 return (base_layer_utils.is_in_keras_graph() or 3333 all(hasattr(t, '_keras_history') for t in input_list)) 3334 3335 3336def _convert_numpy_or_python_types(x): 3337 if isinstance(x, (np_arrays.ndarray, np.ndarray, float, int)): 3338 return ops.convert_to_tensor_v2_with_dispatch(x) 3339 return x 3340 3341 3342# Avoid breaking users who directly import this symbol from this file. 3343# TODO(fchollet): remove this. 3344InputSpec = input_spec.InputSpec # pylint:disable=invalid-name 3345