1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Contains the base Layer class, from which all layers inherit.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import inspect # Necessary supplement to tf_inspect to deal with variadic args. 23import itertools 24 25import numpy as np 26from six.moves import zip # pylint: disable=redefined-builtin 27 28from tensorflow.core.framework import node_def_pb2 29from tensorflow.python.distribute import values as distribute_values 30from tensorflow.python.eager import context 31from tensorflow.python.eager import function 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import func_graph 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.keras import backend 37from tensorflow.python.keras import constraints 38from tensorflow.python.keras import initializers 39from tensorflow.python.keras import regularizers 40from tensorflow.python.keras.engine import base_layer_utils 41from tensorflow.python.keras.engine import input_spec 42from tensorflow.python.keras.mixed_precision.experimental import autocast_variable 43from tensorflow.python.keras.mixed_precision.experimental import policy 44from tensorflow.python.keras.utils import generic_utils 45from tensorflow.python.keras.utils import tf_utils 46# A module that only depends on `keras.layers` import these from here. 47from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import 48from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import 49from tensorflow.python.ops import array_ops 50from tensorflow.python.ops import math_ops 51from tensorflow.python.ops import resource_variable_ops 52from tensorflow.python.ops import variables as tf_variables 53from tensorflow.python.training.tracking import base as trackable 54from tensorflow.python.training.tracking import data_structures 55from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils 56from tensorflow.python.training.tracking import object_identity 57from tensorflow.python.util import function_utils 58from tensorflow.python.util import nest 59from tensorflow.python.util import tf_decorator 60from tensorflow.python.util import tf_inspect 61from tensorflow.python.util.tf_export import keras_export 62from tensorflow.tools.docs import doc_controls 63 64 65@keras_export('keras.layers.Layer') 66class Layer(trackable.Trackable): 67 """Base layer class. 68 69 This is the class from which all layers inherit. 70 71 A layer is a class implementing common neural networks operations, such 72 as convolution, batch norm, etc. These operations require managing weights, 73 losses, updates, and inter-layer connectivity. 74 75 Users will just instantiate a layer and then treat it as a callable. 76 77 We recommend that descendants of `Layer` implement the following methods: 78 79 * `__init__()`: Save configuration in member variables 80 * `build()`: Called once from `__call__`, when we know the shapes of inputs 81 and `dtype`. Should have the calls to `add_weight()`, and then 82 call the super's `build()` (which sets `self.built = True`, which is 83 nice in case the user wants to call `build()` manually before the 84 first `__call__`). 85 * `call()`: Called in `__call__` after making sure `build()` has been called 86 once. Should actually perform the logic of applying the layer to the 87 input tensors (which should be passed in as the first argument). 88 89 Arguments: 90 trainable: Boolean, whether the layer's variables should be trainable. 91 name: String name of the layer. 92 dtype: Default dtype of the layer's weights (default of `None` means use the 93 type of the first input). 94 dynamic: Set this to `True` if your layer should only be run eagerly, and 95 should not be used to generate a static computation graph. 96 This would be the case for a Tree-RNN or a recursive network, 97 for example, or generally for any layer that manipulates tensors 98 using Python control flow. If `False`, we assume that the layer can 99 safely be used to generate a static computation graph. 100 101 Read-only properties: 102 name: The name of the layer (string). 103 dtype: Default dtype of the layer's weights (default of `None` means use the 104 type of the first input). 105 updates: List of update ops of this layer. 106 losses: List of losses added by this layer. 107 trainable_weights: List of variables to be included in backprop. 108 non_trainable_weights: List of variables that should not be 109 included in backprop. 110 weights: The concatenation of the lists trainable_weights and 111 non_trainable_weights (in this order). 112 113 Mutable properties: 114 trainable: Whether the layer should be trained (boolean). 115 input_spec: Optional (list of) `InputSpec` object(s) specifying the 116 constraints on inputs that can be accepted by the layer. 117 """ 118 119 @trackable.no_automatic_dependency_tracking 120 def __init__(self, trainable=True, name=None, dtype=None, dynamic=False, 121 **kwargs): 122 # These properties should be set by the user via keyword arguments. 123 # note that 'dtype', 'input_shape' and 'batch_input_shape' 124 # are only applicable to input layers: do not pass these keywords 125 # to non-input layers. 126 allowed_kwargs = { 127 'input_shape', 128 'batch_input_shape', 129 'batch_size', 130 'weights', 131 'activity_regularizer', 132 } 133 # Validate optional keyword arguments. 134 for kwarg in kwargs: 135 if kwarg not in allowed_kwargs: 136 raise TypeError('Keyword argument not understood:', kwarg) 137 138 # Mutable properties 139 # Indicates whether the layer's weights are updated during training 140 # and whether the layer's updates are run during training 141 self.trainable = trainable 142 # A stateful layer is a layer whose updates are run during inference too, 143 # for instance stateful RNNs. 144 self.stateful = False 145 # Indicates whether `build` needs to be called upon layer call, to create 146 # the layer's weights. 147 self.built = False 148 # Provides information about which inputs are compatible with the layer. 149 self.input_spec = None 150 self.supports_masking = False 151 152 self._init_set_name(name) 153 self._activity_regularizer = kwargs.pop('activity_regularizer', None) 154 if not hasattr(self, '_trainable_weights'): 155 self._trainable_weights = [] 156 if not hasattr(self, '_non_trainable_weights'): 157 self._non_trainable_weights = [] 158 self._updates = [] 159 # A list of zero-argument lambdas which return Tensors, used for variable 160 # regularizers. 161 self._callable_losses = [] 162 # A list of symbolic Tensors containing activity regularizers and losses 163 # manually added through `add_loss` in graph-building mode. 164 self._losses = [] 165 # A list of loss values containing activity regularizers and losses 166 # manually added through `add_loss` during eager execution. It is cleared 167 # after every batch. 168 # Because we plan on eventually allowing a same model instance to be trained 169 # in eager mode or graph mode alternatively, we need to keep track of 170 # eager losses and symbolic losses via separate attributes. 171 self._eager_losses = [] 172 # A list of metric instances corresponding to the symbolic metric tensors 173 # added using the `add_metric` API. 174 self._metrics = [] 175 # TODO(psv): Remove this property. 176 # A dictionary that maps metric names to metric result tensors. The results 177 # are the running averages of metric values over an epoch. 178 self._metrics_tensors = {} 179 180 self._set_dtype_and_policy(dtype) 181 182 self._call_fn_args = function_utils.fn_args(self.call) 183 self._compute_previous_mask = ('mask' in self._call_fn_args or 184 hasattr(self, 'compute_mask')) 185 self._call_convention = (base_layer_utils 186 .CallConvention.EXPLICIT_INPUTS_ARGUMENT) 187 if not hasattr(self, '_layers'): 188 self._layers = [] # Dependencies tracked via attribute assignment. 189 190 # These lists will be filled via successive calls 191 # to self._add_inbound_node(). 192 self._inbound_nodes = [] 193 self._outbound_nodes = [] 194 195 call_argspec = tf_inspect.getfullargspec(self.call) 196 if 'training' in call_argspec.args: 197 self._expects_training_arg = True 198 else: 199 self._expects_training_arg = False 200 201 # Whether the `call` method can be used to build a TF graph without issues. 202 self._dynamic = dynamic 203 204 # Manage input shape information if passed. 205 if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: 206 # In this case we will later create an input layer 207 # to insert before the current layer 208 if 'batch_input_shape' in kwargs: 209 batch_input_shape = tuple(kwargs['batch_input_shape']) 210 elif 'input_shape' in kwargs: 211 if 'batch_size' in kwargs: 212 batch_size = kwargs['batch_size'] 213 else: 214 batch_size = None 215 batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) 216 self._batch_input_shape = batch_input_shape 217 218 # Manage initial weight values if passed. 219 if 'weights' in kwargs: 220 self._initial_weights = kwargs['weights'] 221 else: 222 self._initial_weights = None 223 224 # This flag is used to keep track of whether symbolic tensors are added to 225 # the model outside of the call context. This is required for disabling 226 # `run_eagerly` on compile. 227 # TODO(b/124303407): Remove this flag after we add support for the use case. 228 self._contains_symbolic_tensors = False 229 230 def build(self, input_shape): 231 """Creates the variables of the layer (optional, for subclass implementers). 232 233 This is a method that implementers of subclasses of `Layer` or `Model` 234 can override if they need a state-creation step in-between 235 layer instantiation and layer call. 236 237 This is typically used to create the weights of `Layer` subclasses. 238 239 Arguments: 240 input_shape: Instance of `TensorShape`, or list of instances of 241 `TensorShape` if the layer expects a list of inputs 242 (one instance per input). 243 """ 244 self.built = True 245 246 @doc_controls.for_subclass_implementers 247 def call(self, inputs, **kwargs): # pylint: disable=unused-argument 248 """This is where the layer's logic lives. 249 250 Arguments: 251 inputs: Input tensor, or list/tuple of input tensors. 252 **kwargs: Additional keyword arguments. 253 254 Returns: 255 A tensor or list/tuple of tensors. 256 """ 257 return inputs 258 259 @doc_controls.for_subclass_implementers 260 def add_weight(self, 261 name=None, 262 shape=None, 263 dtype=None, 264 initializer=None, 265 regularizer=None, 266 trainable=None, 267 constraint=None, 268 partitioner=None, 269 use_resource=None, 270 synchronization=tf_variables.VariableSynchronization.AUTO, 271 aggregation=tf_variables.VariableAggregation.NONE, 272 **kwargs): 273 """Adds a new variable to the layer. 274 275 Arguments: 276 name: Variable name. 277 shape: Variable shape. Defaults to scalar if unspecified. 278 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 279 initializer: initializer instance (callable). 280 regularizer: regularizer instance (callable). 281 trainable: whether the variable should be part of the layer's 282 "trainable_variables" (e.g. variables, biases) 283 or "non_trainable_variables" (e.g. BatchNorm mean, stddev). 284 Note, if the current variable scope is marked as non-trainable 285 then this parameter is ignored and any added variables are also 286 marked as non-trainable. `trainable` defaults to `True` unless 287 `synchronization` is set to `ON_READ`. 288 constraint: constraint instance (callable). 289 partitioner: Partitioner to be passed to the `Trackable` API. 290 use_resource: Whether to use `ResourceVariable`. 291 synchronization: Indicates when a distributed a variable will be 292 aggregated. Accepted values are constants defined in the class 293 `tf.VariableSynchronization`. By default the synchronization is set to 294 `AUTO` and the current `DistributionStrategy` chooses 295 when to synchronize. If `synchronization` is set to `ON_READ`, 296 `trainable` must not be set to `True`. 297 aggregation: Indicates how a distributed variable will be aggregated. 298 Accepted values are constants defined in the class 299 `tf.VariableAggregation`. 300 **kwargs: Additional keyword arguments. Accepted values are `getter` and 301 `collections`. 302 303 Returns: 304 The created variable. Usually either a `Variable` or `ResourceVariable` 305 instance. If `partitioner` is not `None`, a `PartitionedVariable` 306 instance is returned. 307 308 Raises: 309 RuntimeError: If called with partioned variable regularization and 310 eager execution is enabled. 311 ValueError: When giving unsupported dtype and no initializer or when 312 trainable has been set to True with synchronization set as `ON_READ`. 313 """ 314 if shape is None: 315 shape = () 316 # Validate optional keyword arguments. 317 for kwarg in kwargs: 318 if kwarg not in ['getter', 'collections', 'experimental_autocast']: 319 raise TypeError('Unknown keyword argument:', kwarg) 320 getter = kwargs.pop('getter', None) 321 collections = kwargs.pop('collections', None) 322 # 'experimental_autocast' can be set to False by the caller to indicate an 323 # AutoCastVariable should never be created. 324 autocast = kwargs.pop('experimental_autocast', True) 325 326 if dtype is None: 327 dtype = self.dtype or backend.floatx() 328 dtype = dtypes.as_dtype(dtype) 329 if self._dtype is None: 330 self._dtype = dtype.base_dtype.name 331 initializer = initializers.get(initializer) 332 regularizer = regularizers.get(regularizer) 333 constraint = constraints.get(constraint) 334 335 if synchronization == tf_variables.VariableSynchronization.ON_READ: 336 if trainable: 337 raise ValueError( 338 'Synchronization value can be set to ' 339 'VariableSynchronization.ON_READ only for non-trainable variables. ' 340 'You have specified trainable=True and ' 341 'synchronization=VariableSynchronization.ON_READ.') 342 else: 343 # Set trainable to be false when variable is to be synced on read. 344 trainable = False 345 elif trainable is None: 346 trainable = True 347 348 # Initialize variable when no initializer provided 349 if initializer is None: 350 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer 351 if dtype.is_floating: 352 initializer = initializers.glorot_uniform() 353 # If dtype is DT_INT/DT_UINT, provide a default value `zero` 354 # If dtype is DT_BOOL, provide a default value `FALSE` 355 elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: 356 initializer = initializers.zeros() 357 # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here? 358 else: 359 raise ValueError('An initializer for variable %s of type %s is required' 360 ' for layer %s' % (name, dtype.base_dtype, self.name)) 361 362 variable = self._add_variable_with_custom_getter( 363 name=name, 364 shape=shape, 365 # TODO(allenl): a `make_variable` equivalent should be added as a 366 # `Trackable` method. 367 getter=getter or base_layer_utils.make_variable, 368 # Manage errors in Layer rather than Trackable. 369 overwrite=True, 370 initializer=initializer, 371 dtype=dtype, 372 constraint=constraint, 373 trainable=trainable and self.trainable, 374 partitioner=partitioner, 375 use_resource=use_resource, 376 collections=collections, 377 synchronization=synchronization, 378 aggregation=aggregation) 379 backend.track_variable(variable) 380 381 if autocast and self._mixed_precision_policy.should_cast_variables: 382 if isinstance(variable, distribute_values.DistributedVariable): 383 variable = autocast_variable.AutoCastDistributedVariable(variable) 384 else: 385 variable = autocast_variable.AutoCastVariable(variable) 386 387 if regularizer is not None: 388 # TODO(fchollet): in the future, this should be handled at the 389 # level of variable creation, and weight regularization losses 390 # should be variable attributes. 391 name_in_scope = variable.name[:variable.name.find(':')] 392 self._handle_weight_regularization(name_in_scope, 393 variable, 394 regularizer) 395 if trainable: 396 self._trainable_weights.append(variable) 397 else: 398 self._non_trainable_weights.append(variable) 399 return variable 400 401 def get_config(self): 402 """Returns the config of the layer. 403 404 A layer config is a Python dictionary (serializable) 405 containing the configuration of a layer. 406 The same layer can be reinstantiated later 407 (without its trained weights) from this configuration. 408 409 The config of a layer does not include connectivity 410 information, nor the layer class name. These are handled 411 by `Network` (one layer of abstraction above). 412 413 Returns: 414 Python dictionary. 415 """ 416 config = {'name': self.name, 'trainable': self.trainable} 417 if hasattr(self, '_batch_input_shape'): 418 config['batch_input_shape'] = self._batch_input_shape 419 if hasattr(self, 'dtype'): 420 config['dtype'] = self.dtype 421 # TODO(reedwm): Handle serializing self._mixed_precision_policy. 422 return config 423 424 @classmethod 425 def from_config(cls, config): 426 """Creates a layer from its config. 427 428 This method is the reverse of `get_config`, 429 capable of instantiating the same layer from the config 430 dictionary. It does not handle layer connectivity 431 (handled by Network), nor weights (handled by `set_weights`). 432 433 Arguments: 434 config: A Python dictionary, typically the 435 output of get_config. 436 437 Returns: 438 A layer instance. 439 """ 440 return cls(**config) 441 442 def compute_output_shape(self, input_shape): 443 """Computes the output shape of the layer. 444 445 Assumes that the layer will be built 446 to match that input shape provided. 447 448 Arguments: 449 input_shape: Shape tuple (tuple of integers) 450 or list of shape tuples (one per output tensor of the layer). 451 Shape tuples can include None for free dimensions, 452 instead of an integer. 453 454 Returns: 455 An input shape tuple. 456 """ 457 if context.executing_eagerly(): 458 # In this case we build the model first in order to do shape inference. 459 # This is acceptable because the framework only calls 460 # `compute_output_shape` on shape values that the layer would later be 461 # built for. It would however cause issues in case a user attempts to 462 # use `compute_output_shape` manually (these users will have to 463 # implement `compute_output_shape` themselves). 464 self.build(input_shape) 465 with context.graph_mode(): 466 graph = func_graph.FuncGraph('graph') 467 with graph.as_default(): 468 input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 469 inputs = nest.map_structure( 470 base_layer_utils.generate_placeholders_from_shape, input_shape) 471 try: 472 if self._expects_training_arg: 473 outputs = self(inputs, training=False) 474 else: 475 outputs = self(inputs) 476 except TypeError: 477 raise NotImplementedError('We could not automatically infer ' 478 'the static shape of the layer\'s output.' 479 ' Please implement the ' 480 '`compute_output_shape` method on your ' 481 'layer (%s).' % self.__class__.__name__) 482 return nest.map_structure(lambda t: t.shape, outputs) 483 raise NotImplementedError 484 485 def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument 486 """Computes an output mask tensor. 487 488 Arguments: 489 inputs: Tensor or list of tensors. 490 mask: Tensor or list of tensors. 491 492 Returns: 493 None or a tensor (or list of tensors, 494 one per output tensor of the layer). 495 """ 496 if not self.supports_masking: 497 if any(m is not None for m in nest.flatten(mask)): 498 raise TypeError('Layer ' + self.name + ' does not support masking, ' 499 'but was passed an input_mask: ' + str(mask)) 500 # masking not explicitly supported: return None as mask. 501 return None 502 # if masking is explicitly supported, by default 503 # carry over the input mask 504 return mask 505 506 def __call__(self, inputs, *args, **kwargs): 507 """Wraps `call`, applying pre- and post-processing steps. 508 509 Arguments: 510 inputs: input tensor(s). 511 *args: additional positional arguments to be passed to `self.call`. 512 **kwargs: additional keyword arguments to be passed to `self.call`. 513 514 Returns: 515 Output tensor(s). 516 517 Note: 518 - The following optional keyword arguments are reserved for specific uses: 519 * `training`: Boolean scalar tensor of Python boolean indicating 520 whether the `call` is meant for training or inference. 521 * `mask`: Boolean input mask. 522 - If the layer's `call` method takes a `mask` argument (as some Keras 523 layers do), its default value will be set to the mask generated 524 for `inputs` by the previous layer (if `input` did come from 525 a layer that generated a corresponding mask, i.e. if it came from 526 a Keras layer with masking support. 527 528 Raises: 529 ValueError: if the layer's `call` method returns None (an invalid value). 530 """ 531 input_list = nest.flatten(inputs) 532 # Accept NumPy inputs by converting to Tensors. 533 if any(isinstance(x, (np.ndarray, float, int)) for x in input_list): 534 # Don't call `ops.convert_to_tensor` on all `inputs` because 535 # `SparseTensors` can't be converted to `Tensor`. 536 def _convert_non_tensor(x): 537 if isinstance(x, (np.ndarray, float, int)): 538 return ops.convert_to_tensor(x) 539 return x 540 541 inputs = nest.map_structure(_convert_non_tensor, inputs) 542 input_list = nest.flatten(inputs) 543 544 # We will attempt to build a TF graph if & only if all inputs are symbolic. 545 # This is always the case in graph mode. It can also be the case in eager 546 # mode when all inputs can be traced back to `keras.Input()` (when building 547 # models using the functional API). 548 build_graph = tf_utils.are_all_symbolic_tensors(input_list) 549 550 if build_graph: 551 # Only create Keras history if at least one tensor originates from a 552 # `keras.Input`. Otherwise this Layer may be being used outside the Keras 553 # framework. 554 if base_layer_utils.needs_keras_history(inputs): 555 base_layer_utils.create_keras_history(inputs) 556 557 # Handle Keras mask propagation from previous layer to current layer. 558 previous_mask = None 559 if (not hasattr(self, '_compute_previous_mask') or 560 self._compute_previous_mask): 561 previous_mask = base_layer_utils.collect_previous_mask(inputs) 562 if not hasattr(self, '_call_fn_args'): 563 self._call_fn_args = function_utils.fn_args(self.call) 564 if ('mask' in self._call_fn_args and 'mask' not in kwargs and 565 not generic_utils.is_all_none(previous_mask)): 566 # The previous layer generated a mask, and mask was not explicitly 567 # pass to __call__, hence we set previous_mask as the default value. 568 kwargs['mask'] = previous_mask 569 570 # Clear eager losses on top level model call. 571 # We are clearing the losses only on the top level model call and not on 572 # every layer/mode call because layer/model may be reused. 573 if (context.executing_eagerly() and 574 not base_layer_utils.is_in_call_context()): 575 self._clear_losses() 576 577 with base_layer_utils.call_context(): 578 # Check input assumptions set after layer building, e.g. input shape. 579 if build_graph: 580 # Symbolic execution on symbolic tensors. We will attempt to build 581 # the corresponding TF subgraph inside `backend.get_graph()` 582 input_spec.assert_input_compatibility(self.input_spec, inputs, 583 self.name) 584 graph = backend.get_graph() 585 with graph.as_default(), ops.name_scope(self._name_scope()): 586 # Build layer if applicable (if the `build` method has been 587 # overridden). 588 self._maybe_build(inputs) 589 # Explicitly pass the learning phase placeholder to `call` if 590 # the `training` argument was left unspecified by the user. 591 # This behavior is restricted to the managed Keras FuncGraph. 592 learning_phase_passed_by_framework = False 593 if (self._expects_training_arg and 594 not base_layer_utils.training_arg_passed_to_call( 595 tf_inspect.getfullargspec(self.call), args, kwargs) and 596 getattr(graph, 'name', None) == 'keras_graph'): 597 learning_phase_passed_by_framework = True 598 kwargs['training'] = backend.learning_phase() 599 if not self.dynamic: 600 try: 601 with base_layer_utils.autocast_context_manager( 602 input_list, 603 self._mixed_precision_policy.should_cast_variables), ( 604 base_layer_utils.AutoAddUpdates(self, 605 inputs)) as auto_updater: 606 outputs = self.call(inputs, *args, **kwargs) 607 auto_updater.set_outputs(outputs) 608 609 except TypeError as e: 610 messages = ('`tf.Tensor` as a Python `bool` is not allowed', 611 'Tensor objects are only iterable when eager') 612 exception_str = str(e) 613 for msg in messages: 614 if msg in exception_str: 615 raise TypeError('You are attempting to use Python control ' 616 'flow in a layer that was not declared to be ' 617 'dynamic. Pass `dynamic=True` to the class ' 618 'constructor.\nEncountered error:\n"""\n' + 619 exception_str + '\n"""') 620 raise 621 else: 622 # We will use static shape inference to return symbolic tensors 623 # matching the specifications of the layer outputs. 624 # Since `self.dynamic` is True, we will never attempt to 625 # run the underlying TF graph (which is disconnected). 626 # TODO(fchollet): consider py_func as an alternative, which 627 # would enable us to run the underlying graph if needed. 628 outputs = self._symbolic_call(inputs) 629 630 if outputs is None: 631 raise ValueError('A layer\'s `call` method should return a ' 632 'Tensor or a list of Tensors, not None ' 633 '(layer: ' + self.name + ').') 634 if base_layer_utils.have_all_keras_metadata(inputs): 635 if learning_phase_passed_by_framework: 636 kwargs.pop('training') 637 inputs, outputs = self._set_connectivity_metadata_( 638 inputs, outputs, args, kwargs) 639 self._handle_activity_regularization(inputs, outputs) 640 self._set_mask_metadata(inputs, outputs, previous_mask) 641 if hasattr(self, '_set_inputs') and not self.inputs: 642 # Subclassed network: explicitly set metadata normally set by 643 # a call to self._set_inputs(). 644 # TODO(b/120997007): This should be done in Eager as well, but 645 # causes garbage collection issues because of the placeholders 646 # created on the default Keras graph. 647 self._set_inputs(inputs, outputs) 648 else: 649 # Eager execution on data tensors. 650 with ops.name_scope(self._name_scope()): 651 self._maybe_build(inputs) 652 with base_layer_utils.autocast_context_manager( 653 input_list, self._mixed_precision_policy.should_cast_variables): 654 outputs = self.call(inputs, *args, **kwargs) 655 self._handle_activity_regularization(inputs, outputs) 656 self._set_mask_metadata(inputs, outputs, previous_mask) 657 658 if not context.executing_eagerly(): 659 # Optionally load weight values specified at layer instantiation. 660 # TODO(fchollet): consider enabling this with eager execution too. 661 if (hasattr(self, '_initial_weights') and 662 self._initial_weights is not None): 663 self.set_weights(self._initial_weights) 664 del self._initial_weights 665 return outputs 666 667 @property 668 def dtype(self): 669 return self._dtype 670 671 @property 672 def name(self): 673 return self._name 674 675 @property 676 def dynamic(self): 677 return self._dynamic 678 679 @property 680 def activity_regularizer(self): 681 """Optional regularizer function for the output of this layer.""" 682 return self._activity_regularizer 683 684 @activity_regularizer.setter 685 def activity_regularizer(self, regularizer): 686 """Optional regularizer function for the output of this layer.""" 687 self._activity_regularizer = regularizer 688 689 @property 690 def trainable_weights(self): 691 if self.trainable: 692 nested = self._gather_children_attribute('trainable_weights') 693 return self._trainable_weights + nested 694 else: 695 return [] 696 697 @property 698 def non_trainable_weights(self): 699 if self.trainable: 700 nested = self._gather_children_attribute('non_trainable_weights') 701 return self._non_trainable_weights + nested 702 else: 703 nested = self._gather_children_attribute('weights') 704 return self._trainable_weights + self._non_trainable_weights + nested 705 706 @property 707 def weights(self): 708 """Returns the list of all layer variables/weights. 709 710 Returns: 711 A list of variables. 712 """ 713 return self.trainable_weights + self.non_trainable_weights 714 715 @property 716 def updates(self): 717 return self._get_unfiltered_updates(check_trainable=True) 718 719 @property 720 def losses(self): 721 """Losses which are associated with this `Layer`. 722 723 Variable regularization tensors are created when this property is accessed, 724 so it is eager safe: accessing `losses` under a `tf.GradientTape` will 725 propagate gradients back to the corresponding variables. 726 727 Returns: 728 A list of tensors. 729 """ 730 collected_losses = [] 731 732 # If any eager losses are present, we assume the model to be part of an 733 # eager training loop (either a custom one or the one used when 734 # `run_eagerly=True`), and so we always return just the eager losses in that 735 # case. 736 if self._eager_losses: 737 collected_losses.extend(self._eager_losses) 738 else: 739 collected_losses.extend(self._losses) 740 for regularizer in self._callable_losses: 741 loss_tensor = regularizer() 742 if loss_tensor is not None: 743 collected_losses.append(loss_tensor) 744 return collected_losses + self._gather_children_attribute('losses') 745 746 @doc_controls.for_subclass_implementers 747 def add_loss(self, losses, inputs=None): 748 """Add loss tensor(s), potentially dependent on layer inputs. 749 750 Some losses (for instance, activity regularization losses) may be dependent 751 on the inputs passed when calling a layer. Hence, when reusing the same 752 layer on different inputs `a` and `b`, some entries in `layer.losses` may 753 be dependent on `a` and some on `b`. This method automatically keeps track 754 of dependencies. 755 756 The `get_losses_for` method allows to retrieve the losses relevant to a 757 specific set of inputs. 758 759 Note that `add_loss` is not supported when executing eagerly. Instead, 760 variable regularizers may be added through `add_variable`. Activity 761 regularization is not supported directly (but such losses may be returned 762 from `Layer.call()`). 763 764 Arguments: 765 losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses 766 may also be zero-argument callables which create a loss tensor. 767 Other types of input are ignored. 768 inputs: Ignored when executing eagerly. If anything other than None is 769 passed, it signals the losses are conditional on some of the layer's 770 inputs, and thus they should only be run where these inputs are 771 available. This is the case for activity regularization losses, for 772 instance. If `None` is passed, the losses are assumed 773 to be unconditional, and will apply across all dataflows of the layer 774 (e.g. weight regularization losses). 775 """ 776 losses = generic_utils.to_list(losses) 777 778 def _tag_unconditional(loss): 779 if callable(loss): 780 loss = loss() 781 if loss is None: 782 return None # Will be filtered out when computing the .losses property 783 if not tensor_util.is_tensor(loss): 784 loss = ops.convert_to_tensor(loss, dtype=backend.floatx()) 785 loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access 786 return loss 787 788 for loss in losses: 789 if callable(loss): 790 self._callable_losses.append( 791 functools.partial(_tag_unconditional, loss)) 792 else: 793 if not tensor_util.is_tensor(loss): 794 # Ignoring constant values as this does not affect the gradients. 795 return 796 if tf_utils.is_symbolic_tensor(loss): 797 if not base_layer_utils.is_in_call_context(): 798 self._contains_symbolic_tensors = True 799 self._losses.append(_tag_unconditional(loss)) 800 else: 801 self._eager_losses.append(_tag_unconditional(loss)) 802 803 @trackable.no_automatic_dependency_tracking 804 def _clear_losses(self): 805 """Used every step in eager to reset losses.""" 806 self._eager_losses = [] 807 if hasattr(self, '_layers'): 808 for layer in trackable_layer_utils.filter_empty_layer_containers( 809 self._layers): 810 layer._clear_losses() 811 812 @doc_controls.for_subclass_implementers 813 def add_metric(self, value, aggregation=None, name=None): 814 """Adds metric tensor to the layer. 815 816 Args: 817 value: Metric tensor. 818 aggregation: Sample-wise metric reduction function. If `aggregation=None`, 819 it indicates that the metric tensor provided has been aggregated 820 already. eg, `bin_acc = BinaryAccuracy(name='acc')` followed by 821 `model.add_metric(bin_acc(y_true, y_pred))`. If aggregation='mean', the 822 given metric tensor will be sample-wise reduced using `mean` function. 823 eg, `model.add_metric(tf.reduce_sum(outputs), name='output_mean', 824 aggregation='mean')`. 825 name: String metric name. 826 827 Raises: 828 ValueError: If `aggregation` is anything other than None or `mean`. 829 """ 830 if aggregation is not None and aggregation != 'mean': 831 raise ValueError( 832 'We currently support only `mean` sample-wise metric aggregation. ' 833 'You provided aggregation=`%s`' % aggregation) 834 835 is_symbolic = tf_utils.is_symbolic_tensor(value) 836 if name is None and (not is_symbolic or not hasattr(value, '_metric_obj')): 837 # Eg. `self.add_metric(math_ops.reduce_sum(x), aggregation='mean')` 838 # In eager mode, we use metric name to lookup a metric. Without a name, 839 # a new Mean metric wrapper will be created on every model/layer call. 840 # So, we raise an error when no name is provided. 841 # We will do the same for symbolic mode for consistency although a name 842 # will be generated if no name is provided. 843 844 # We will not raise this error in the foll use case for the sake of 845 # consistency as name in provided in the metric constructor. 846 # mean = metrics.Mean(name='my_metric') 847 # model.add_metric(mean(outputs)) 848 raise ValueError('Please provide a name for your metric like ' 849 '`self.add_metric(tf.reduce_sum(inputs), ' 850 'name=\'mean_activation\', aggregation=\'mean\')`') 851 852 if is_symbolic: 853 with backend.get_graph().as_default(): 854 self._symbolic_add_metric(value, aggregation, name) 855 else: 856 self._eager_add_metric(value, aggregation, name) 857 858 @doc_controls.for_subclass_implementers 859 def add_update(self, updates, inputs=None): 860 """Add update op(s), potentially dependent on layer inputs. 861 862 Weight updates (for instance, the updates of the moving mean and variance 863 in a BatchNormalization layer) may be dependent on the inputs passed 864 when calling a layer. Hence, when reusing the same layer on 865 different inputs `a` and `b`, some entries in `layer.updates` may be 866 dependent on `a` and some on `b`. This method automatically keeps track 867 of dependencies. 868 869 The `get_updates_for` method allows to retrieve the updates relevant to a 870 specific set of inputs. 871 872 This call is ignored when eager execution is enabled (in that case, variable 873 updates are run on the fly and thus do not need to be tracked for later 874 execution). 875 876 Arguments: 877 updates: Update op, or list/tuple of update ops. 878 inputs: If anything other than None is passed, it signals the updates 879 are conditional on some of the layer's inputs, 880 and thus they should only be run where these inputs are available. 881 This is the case for BatchNormalization updates, for instance. 882 If None, the updates will be taken into account unconditionally, 883 and you are responsible for making sure that any dependency they might 884 have is available at runtime. 885 A step counter might fall into this category. 886 """ 887 if context.executing_eagerly(): 888 return # Updates already applied when in eager mode. 889 890 def process_update(x): 891 if isinstance(x, ops.Operation): 892 return x 893 elif hasattr(x, 'op'): 894 return x.op 895 else: 896 return ops.convert_to_tensor(x) 897 898 updates = generic_utils.to_list(updates) 899 updates = [process_update(x) for x in updates] 900 self._updates += updates 901 if inputs is None: 902 for u in updates: 903 u._unconditional_update = True # pylint: disable=protected-access 904 else: 905 for u in updates: 906 u._unconditional_update = False # pylint: disable=protected-access 907 908 def set_weights(self, weights): 909 """Sets the weights of the layer, from Numpy arrays. 910 911 Arguments: 912 weights: a list of Numpy arrays. The number 913 of arrays and their shape must match 914 number of the dimensions of the weights 915 of the layer (i.e. it should match the 916 output of `get_weights`). 917 918 Raises: 919 ValueError: If the provided weights list does not match the 920 layer's specifications. 921 """ 922 params = self.weights 923 if len(params) != len(weights): 924 raise ValueError('You called `set_weights(weights)` on layer "' + 925 self.name + '" with a weight list of length ' + 926 str(len(weights)) + ', but the layer was expecting ' + 927 str(len(params)) + ' weights. Provided weights: ' + 928 str(weights)[:50] + '...') 929 if not params: 930 return 931 weight_value_tuples = [] 932 param_values = backend.batch_get_value(params) 933 for pv, p, w in zip(param_values, params, weights): 934 if pv.shape != w.shape: 935 raise ValueError('Layer weight shape ' + str(pv.shape) + 936 ' not compatible with ' 937 'provided weight shape ' + str(w.shape)) 938 weight_value_tuples.append((p, w)) 939 backend.batch_set_value(weight_value_tuples) 940 941 def get_weights(self): 942 """Returns the current weights of the layer. 943 944 Returns: 945 Weights values as a list of numpy arrays. 946 """ 947 params = self.weights 948 return backend.batch_get_value(params) 949 950 def get_updates_for(self, inputs): 951 """Retrieves updates relevant to a specific set of inputs. 952 953 Arguments: 954 inputs: Input tensor or list/tuple of input tensors. 955 956 Returns: 957 List of update ops of the layer that depend on `inputs`. 958 959 Raises: 960 RuntimeError: If called in Eager mode. 961 """ 962 # Updates disabled if layer is not trainable and not explicitly stateful. 963 if not self.trainable and not self.stateful: 964 return [] 965 966 if inputs is None: 967 # Requesting unconditional updates. 968 return [ 969 x for x in self._get_unfiltered_updates() if x._unconditional_update # pylint: disable=protected-access 970 ] 971 972 # Requesting input-conditional updates. 973 inputs = nest.flatten(inputs) 974 reachable = tf_utils.get_reachable_from_inputs( 975 inputs, self._get_unfiltered_updates()) 976 return [u for u in self._get_unfiltered_updates() if u in reachable] # pylint: disable=protected-access 977 978 def get_losses_for(self, inputs): 979 """Retrieves losses relevant to a specific set of inputs. 980 981 Arguments: 982 inputs: Input tensor or list/tuple of input tensors. 983 984 Returns: 985 List of loss tensors of the layer that depend on `inputs`. 986 987 Raises: 988 RuntimeError: If called in Eager mode. 989 """ 990 if inputs is None: 991 # Requesting unconditional losses. 992 return [x for x in self.losses if x._unconditional_loss] # pylint: disable=protected-access 993 994 # Requesting input-conditional losses. 995 inputs = nest.flatten(inputs) 996 # Retrieve the set of tensors in the TF graph that depend on `inputs`. 997 # The losses we want to return will be part of this set. 998 # To avoid unnecessary work, we stop the search in case all of 999 # `self.losses` have been retrieved. 1000 reachable = tf_utils.get_reachable_from_inputs(inputs, self.losses) 1001 losses = [] 1002 for loss in self.losses: 1003 if loss in reachable: 1004 losses.append(loss) 1005 return losses 1006 1007 def get_input_mask_at(self, node_index): 1008 """Retrieves the input mask tensor(s) of a layer at a given node. 1009 1010 Arguments: 1011 node_index: Integer, index of the node 1012 from which to retrieve the attribute. 1013 E.g. `node_index=0` will correspond to the 1014 first time the layer was called. 1015 1016 Returns: 1017 A mask tensor 1018 (or list of tensors if the layer has multiple inputs). 1019 """ 1020 inputs = self.get_input_at(node_index) 1021 if isinstance(inputs, list): 1022 return [getattr(x, '_keras_mask', None) for x in inputs] 1023 else: 1024 return getattr(inputs, '_keras_mask', None) 1025 1026 def get_output_mask_at(self, node_index): 1027 """Retrieves the output mask tensor(s) of a layer at a given node. 1028 1029 Arguments: 1030 node_index: Integer, index of the node 1031 from which to retrieve the attribute. 1032 E.g. `node_index=0` will correspond to the 1033 first time the layer was called. 1034 1035 Returns: 1036 A mask tensor 1037 (or list of tensors if the layer has multiple outputs). 1038 """ 1039 output = self.get_output_at(node_index) 1040 if isinstance(output, list): 1041 return [getattr(x, '_keras_mask', None) for x in output] 1042 else: 1043 return getattr(output, '_keras_mask', None) 1044 1045 @property 1046 def input_mask(self): 1047 """Retrieves the input mask tensor(s) of a layer. 1048 1049 Only applicable if the layer has exactly one inbound node, 1050 i.e. if it is connected to one incoming layer. 1051 1052 Returns: 1053 Input mask tensor (potentially None) or list of input 1054 mask tensors. 1055 1056 Raises: 1057 AttributeError: if the layer is connected to 1058 more than one incoming layers. 1059 """ 1060 inputs = self.input 1061 if isinstance(inputs, list): 1062 return [getattr(x, '_keras_mask', None) for x in inputs] 1063 else: 1064 return getattr(inputs, '_keras_mask', None) 1065 1066 @property 1067 def output_mask(self): 1068 """Retrieves the output mask tensor(s) of a layer. 1069 1070 Only applicable if the layer has exactly one inbound node, 1071 i.e. if it is connected to one incoming layer. 1072 1073 Returns: 1074 Output mask tensor (potentially None) or list of output 1075 mask tensors. 1076 1077 Raises: 1078 AttributeError: if the layer is connected to 1079 more than one incoming layers. 1080 """ 1081 output = self.output 1082 if isinstance(output, list): 1083 return [getattr(x, '_keras_mask', None) for x in output] 1084 else: 1085 return getattr(output, '_keras_mask', None) 1086 1087 def get_input_shape_at(self, node_index): 1088 """Retrieves the input shape(s) of a layer at a given node. 1089 1090 Arguments: 1091 node_index: Integer, index of the node 1092 from which to retrieve the attribute. 1093 E.g. `node_index=0` will correspond to the 1094 first time the layer was called. 1095 1096 Returns: 1097 A shape tuple 1098 (or list of shape tuples if the layer has multiple inputs). 1099 1100 Raises: 1101 RuntimeError: If called in Eager mode. 1102 """ 1103 return self._get_node_attribute_at_index(node_index, 'input_shapes', 1104 'input shape') 1105 1106 def get_output_shape_at(self, node_index): 1107 """Retrieves the output shape(s) of a layer at a given node. 1108 1109 Arguments: 1110 node_index: Integer, index of the node 1111 from which to retrieve the attribute. 1112 E.g. `node_index=0` will correspond to the 1113 first time the layer was called. 1114 1115 Returns: 1116 A shape tuple 1117 (or list of shape tuples if the layer has multiple outputs). 1118 1119 Raises: 1120 RuntimeError: If called in Eager mode. 1121 """ 1122 return self._get_node_attribute_at_index(node_index, 'output_shapes', 1123 'output shape') 1124 1125 def get_input_at(self, node_index): 1126 """Retrieves the input tensor(s) of a layer at a given node. 1127 1128 Arguments: 1129 node_index: Integer, index of the node 1130 from which to retrieve the attribute. 1131 E.g. `node_index=0` will correspond to the 1132 first time the layer was called. 1133 1134 Returns: 1135 A tensor (or list of tensors if the layer has multiple inputs). 1136 1137 Raises: 1138 RuntimeError: If called in Eager mode. 1139 """ 1140 return self._get_node_attribute_at_index(node_index, 'input_tensors', 1141 'input') 1142 1143 def get_output_at(self, node_index): 1144 """Retrieves the output tensor(s) of a layer at a given node. 1145 1146 Arguments: 1147 node_index: Integer, index of the node 1148 from which to retrieve the attribute. 1149 E.g. `node_index=0` will correspond to the 1150 first time the layer was called. 1151 1152 Returns: 1153 A tensor (or list of tensors if the layer has multiple outputs). 1154 1155 Raises: 1156 RuntimeError: If called in Eager mode. 1157 """ 1158 return self._get_node_attribute_at_index(node_index, 'output_tensors', 1159 'output') 1160 1161 @property 1162 def input(self): 1163 """Retrieves the input tensor(s) of a layer. 1164 1165 Only applicable if the layer has exactly one input, 1166 i.e. if it is connected to one incoming layer. 1167 1168 Returns: 1169 Input tensor or list of input tensors. 1170 1171 Raises: 1172 AttributeError: if the layer is connected to 1173 more than one incoming layers. 1174 1175 Raises: 1176 RuntimeError: If called in Eager mode. 1177 AttributeError: If no inbound nodes are found. 1178 """ 1179 if not self._inbound_nodes: 1180 raise AttributeError('Layer ' + self.name + 1181 ' is not connected, no input to return.') 1182 return self._get_node_attribute_at_index(0, 'input_tensors', 'input') 1183 1184 @property 1185 def output(self): 1186 """Retrieves the output tensor(s) of a layer. 1187 1188 Only applicable if the layer has exactly one output, 1189 i.e. if it is connected to one incoming layer. 1190 1191 Returns: 1192 Output tensor or list of output tensors. 1193 1194 Raises: 1195 AttributeError: if the layer is connected to more than one incoming 1196 layers. 1197 RuntimeError: if called in Eager mode. 1198 """ 1199 if not self._inbound_nodes: 1200 raise AttributeError('Layer ' + self.name + ' has no inbound nodes.') 1201 return self._get_node_attribute_at_index(0, 'output_tensors', 'output') 1202 1203 @property 1204 def input_shape(self): 1205 """Retrieves the input shape(s) of a layer. 1206 1207 Only applicable if the layer has exactly one input, 1208 i.e. if it is connected to one incoming layer, or if all inputs 1209 have the same shape. 1210 1211 Returns: 1212 Input shape, as an integer shape tuple 1213 (or list of shape tuples, one tuple per input tensor). 1214 1215 Raises: 1216 AttributeError: if the layer has no defined input_shape. 1217 RuntimeError: if called in Eager mode. 1218 """ 1219 if not self._inbound_nodes: 1220 raise AttributeError('The layer has never been called ' 1221 'and thus has no defined input shape.') 1222 all_input_shapes = set( 1223 [str(node.input_shapes) for node in self._inbound_nodes]) 1224 if len(all_input_shapes) == 1: 1225 return self._inbound_nodes[0].input_shapes 1226 else: 1227 raise AttributeError('The layer "' + str(self.name) + 1228 ' has multiple inbound nodes, ' 1229 'with different input shapes. Hence ' 1230 'the notion of "input shape" is ' 1231 'ill-defined for the layer. ' 1232 'Use `get_input_shape_at(node_index)` ' 1233 'instead.') 1234 1235 def count_params(self): 1236 """Count the total number of scalars composing the weights. 1237 1238 Returns: 1239 An integer count. 1240 1241 Raises: 1242 ValueError: if the layer isn't yet built 1243 (in which case its weights aren't yet defined). 1244 """ 1245 if not self.built: 1246 if self.__class__.__name__ == 'Sequential': 1247 self.build() # pylint: disable=no-value-for-parameter 1248 else: 1249 raise ValueError('You tried to call `count_params` on ' + self.name + 1250 ', but the layer isn\'t built. ' 1251 'You can build it manually via: `' + self.name + 1252 '.build(batch_input_shape)`.') 1253 return int(sum(np.prod(w.shape.as_list()) for w in self.weights)) 1254 1255 @property 1256 def output_shape(self): 1257 """Retrieves the output shape(s) of a layer. 1258 1259 Only applicable if the layer has one output, 1260 or if all outputs have the same shape. 1261 1262 Returns: 1263 Output shape, as an integer shape tuple 1264 (or list of shape tuples, one tuple per output tensor). 1265 1266 Raises: 1267 AttributeError: if the layer has no defined output shape. 1268 RuntimeError: if called in Eager mode. 1269 """ 1270 if not self._inbound_nodes: 1271 raise AttributeError('The layer has never been called ' 1272 'and thus has no defined output shape.') 1273 all_output_shapes = set( 1274 [str(node.output_shapes) for node in self._inbound_nodes]) 1275 if len(all_output_shapes) == 1: 1276 return self._inbound_nodes[0].output_shapes 1277 else: 1278 raise AttributeError('The layer "%s"' 1279 ' has multiple inbound nodes, ' 1280 'with different output shapes. Hence ' 1281 'the notion of "output shape" is ' 1282 'ill-defined for the layer. ' 1283 'Use `get_output_shape_at(node_index)` ' 1284 'instead.' % self.name) 1285 1286 @property 1287 @doc_controls.do_not_doc_inheritable 1288 def inbound_nodes(self): 1289 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 1290 return self._inbound_nodes 1291 1292 @property 1293 @doc_controls.do_not_doc_inheritable 1294 def outbound_nodes(self): 1295 """Deprecated, do NOT use! Only for compatibility with external Keras.""" 1296 return self._outbound_nodes 1297 1298 ############################################################################## 1299 # Methods & attributes below are public aliases of other methods. # 1300 ############################################################################## 1301 1302 def apply(self, inputs, *args, **kwargs): 1303 """Apply the layer on a input. 1304 1305 This is an alias of `self.__call__`. 1306 1307 Arguments: 1308 inputs: Input tensor(s). 1309 *args: additional positional arguments to be passed to `self.call`. 1310 **kwargs: additional keyword arguments to be passed to `self.call`. 1311 1312 Returns: 1313 Output tensor(s). 1314 """ 1315 return self.__call__(inputs, *args, **kwargs) 1316 1317 @doc_controls.for_subclass_implementers 1318 def add_variable(self, *args, **kwargs): 1319 """Alias for `add_weight`.""" 1320 return self.add_weight(*args, **kwargs) 1321 1322 @property 1323 def variables(self): 1324 """Returns the list of all layer variables/weights. 1325 1326 Alias of `self.weights`. 1327 1328 Returns: 1329 A list of variables. 1330 """ 1331 return self.weights 1332 1333 @property 1334 def trainable_variables(self): 1335 return self.trainable_weights 1336 1337 @property 1338 def non_trainable_variables(self): 1339 return self.non_trainable_weights 1340 1341 ############################################################################## 1342 # Methods & attributes below are all private and only used by the framework. # 1343 ############################################################################## 1344 1345 def _set_dtype_and_policy(self, dtype): 1346 """Sets self._dtype and self._mixed_precision_policy.""" 1347 if dtype: 1348 if isinstance(dtype, policy.Policy): 1349 self._mixed_precision_policy = dtype 1350 self._dtype = self._mixed_precision_policy.default_variable_dtype 1351 else: 1352 # If a non-policy dtype is passed, no casting should be done. So we use 1353 # the "infer" policy, which does no casting. 1354 self._mixed_precision_policy = policy.Policy('infer') 1355 self._dtype = dtypes.as_dtype(dtype).name 1356 else: 1357 self._mixed_precision_policy = policy.global_policy() 1358 # If the global policy has not been set, it will be an "infer" policy 1359 # without a default variable dtype, and so self._dtype will be None. In 1360 # that case, self._dtype will be set when the layer is built or called. 1361 self._dtype = self._mixed_precision_policy.default_variable_dtype 1362 1363 def _name_scope(self): 1364 return self.name 1365 1366 def _init_set_name(self, name, zero_based=True): 1367 if not name: 1368 self._name = base_layer_utils.unique_layer_name( 1369 generic_utils.to_snake_case(self.__class__.__name__), 1370 zero_based=zero_based) 1371 else: 1372 self._name = name 1373 1374 def _get_existing_metric(self, name=None): 1375 match = [m for m in self._metrics if m.name == name] 1376 if not match: 1377 return 1378 if len(match) > 1: 1379 raise ValueError( 1380 'Please provide different names for the metrics you have added. ' 1381 'We found {} metrics with the name: "{}"'.format(len(match), name)) 1382 return match[0] 1383 1384 def _eager_add_metric(self, value, aggregation=None, name=None): 1385 # If the given metric is available in `metrics` list we just update state 1386 # on it, otherwise we create a new metric instance and 1387 # add it to the `metrics` list. 1388 match = self._get_existing_metric(name) 1389 if match: 1390 match(value) # Update the metric state. 1391 return 1392 else: 1393 # Aggregation will always be set in this use case. If not we will raise 1394 # error on model/layer call in graph function mode when model/layer is 1395 # created. 1396 assert aggregation is not None 1397 metric_obj, _ = base_layer_utils.create_mean_metric(value, name) 1398 self._metrics.append(metric_obj) 1399 1400 def _symbolic_add_metric(self, value, aggregation=None, name=None): 1401 if not base_layer_utils.is_in_call_context(): 1402 self._contains_symbolic_tensors = True 1403 if aggregation is None: 1404 # Iterate over the metrics and check if the given metric exists already. 1405 # This can happen when a metric instance is created in subclassed model 1406 # layer `__init__` and we have tracked that instance already in 1407 # model.__setattr__. 1408 match = self._get_existing_metric(name) 1409 if match: 1410 result_tensor = value 1411 if match.name not in self._metrics_tensors: 1412 self._metrics_tensors[match.name] = result_tensor 1413 return 1414 else: 1415 raise ValueError( 1416 'We currently do not support reusing a metric instance.') 1417 elif hasattr(value, '_metric_obj'): 1418 # We track the instance using the metadata on the result tensor. 1419 result_tensor = value 1420 metric_obj = result_tensor._metric_obj 1421 else: 1422 raise ValueError( 1423 'We do not support adding an aggregated metric result tensor that ' 1424 'is not the output of a `tf.keras.metrics.Metric` metric instance. ' 1425 'Without having access to the metric instance we cannot reset the ' 1426 'state of a metric after every epoch during training. You can ' 1427 'create a `tf.keras.metrics.Metric` instance and pass the result ' 1428 'here or pass an un-aggregated result with `aggregation` parameter ' 1429 'set as `mean`. For example: `self.add_metric(tf.reduce_sum(inputs)' 1430 ', name=\'mean_activation\', aggregation=\'mean\')`') 1431 else: 1432 # If a non-aggregated tensor is given as input (ie. `aggregation` is 1433 # explicitly set to `mean`), we wrap the tensor in `Mean` metric. 1434 metric_obj, result_tensor = base_layer_utils.create_mean_metric( 1435 value, name) 1436 self._metrics.append(metric_obj) 1437 self._metrics_tensors[metric_obj.name] = result_tensor 1438 1439 def _handle_weight_regularization(self, name, variable, regularizer): 1440 """Create lambdas which compute regularization losses.""" 1441 1442 def _loss_for_variable(v): 1443 """Creates a regularization loss `Tensor` for variable `v`.""" 1444 with ops.name_scope(name + '/Regularizer'): 1445 regularization = regularizer(v) 1446 return regularization 1447 1448 if isinstance(variable, tf_variables.PartitionedVariable): 1449 for v in variable: 1450 self.add_loss(functools.partial(_loss_for_variable, v)) 1451 else: 1452 self.add_loss(functools.partial(_loss_for_variable, variable)) 1453 1454 def _handle_activity_regularization(self, inputs, outputs): 1455 # Apply activity regularization. 1456 # Note that it should be applied every time the layer creates a new 1457 # output, since it is output-specific. 1458 if self._activity_regularizer: 1459 output_list = nest.flatten(outputs) 1460 with ops.name_scope('ActivityRegularizer'): 1461 for output in output_list: 1462 activity_loss = self._activity_regularizer(output) 1463 batch_size = math_ops.cast( 1464 array_ops.shape(output)[0], activity_loss.dtype) 1465 # Make activity regularization strength batch-agnostic. 1466 mean_activity_loss = activity_loss / batch_size 1467 self.add_loss(mean_activity_loss, inputs=inputs) 1468 1469 def _set_mask_metadata(self, inputs, outputs, previous_mask): 1470 flat_outputs = nest.flatten(outputs) 1471 mask_already_computed = ( 1472 getattr(self, '_compute_output_and_mask_jointly', False) or 1473 all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) 1474 1475 if not mask_already_computed: 1476 if hasattr(self, 'compute_mask'): 1477 output_masks = self.compute_mask(inputs, previous_mask) 1478 # `compute_mask` can return a single `None` even when a Layer 1479 # has multiple outputs. 1480 if output_masks is None: 1481 flat_masks = [None for _ in flat_outputs] 1482 else: 1483 flat_masks = nest.flatten(output_masks) 1484 else: 1485 flat_masks = [None for _ in flat_outputs] 1486 1487 for output, mask in zip(flat_outputs, flat_masks): 1488 try: 1489 output._keras_mask = mask 1490 except AttributeError: 1491 # C Type such as np.ndarray. 1492 pass 1493 1494 if tf_utils.are_all_symbolic_tensors(flat_outputs): 1495 for output in flat_outputs: 1496 if getattr(output, '_keras_mask', None) is not None: 1497 # Do not track masks for `TensorFlowOpLayer` construction. 1498 output._keras_mask._keras_history_checked = True 1499 1500 def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): 1501 call_convention = getattr( 1502 self, '_call_convention', 1503 base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT) 1504 if args: 1505 if call_convention == (base_layer_utils 1506 .CallConvention.EXPLICIT_INPUTS_ARGUMENT): 1507 raise TypeError( 1508 'This layer ("{}") takes an `inputs` argument in `call()`, ' 1509 'and only the `inputs` argument may be specified as a positional ' 1510 'argument. Pass everything else as a keyword argument ' 1511 '(those arguments will not be tracked ' 1512 'as inputs to the layer).'.format(self.name)) 1513 elif call_convention == (base_layer_utils 1514 .CallConvention.SINGLE_POSITIONAL_ARGUMENT): 1515 raise TypeError( 1516 'This layer ("{}") takes a single positional argument in `call()`,' 1517 ' which is by convention the `inputs` argument, ' 1518 'and only this argument may be specified as a positional argument. ' 1519 'Pass everything else as a keyword argument ' 1520 '(those arguments will not be tracked ' 1521 'as inputs to the layer).'.format(self.name)) 1522 1523 # If the layer returns tensors from its inputs, unmodified, 1524 # we copy them to avoid loss of tensor metadata. 1525 output_ls = nest.flatten(outputs) 1526 inputs_ls = nest.flatten(inputs) 1527 output_ls_copy = [] 1528 for x in output_ls: 1529 if x in inputs_ls: 1530 with ops.name_scope(self.name): 1531 x = array_ops.identity(x) 1532 output_ls_copy.append(x) 1533 outputs = nest.pack_sequence_as(outputs, output_ls_copy) 1534 1535 inputs, kwargs = self._inputs_from_call_args( 1536 call_args=(inputs,) + args, call_kwargs=kwargs) 1537 # Add an inbound node to the layer, so it can keep track of this call. 1538 # This updates the layer history of the output tensor(s). 1539 kwargs.pop('mask', None) # `mask` should not be serialized. 1540 self._add_inbound_node( 1541 input_tensors=inputs, output_tensors=outputs, arguments=kwargs) 1542 return inputs, outputs 1543 1544 def _inputs_from_call_args(self, call_args, call_kwargs): 1545 """Get Layer inputs from __call__ *args and **kwargs. 1546 1547 Args: 1548 call_args: The positional arguments passed to __call__. 1549 call_kwargs: The keyword argument dict passed to __call__. 1550 1551 Returns: 1552 A tuple of (inputs, non_input_kwargs). These may be the same objects as 1553 were passed in (call_args and call_kwargs). 1554 """ 1555 call_convention = getattr( 1556 self, '_call_convention', 1557 base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT) 1558 if (call_convention in ( 1559 base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT, 1560 base_layer_utils.CallConvention.SINGLE_POSITIONAL_ARGUMENT)): 1561 assert len(call_args) == 1 # TypeError raised earlier in __call__. 1562 return call_args[0], call_kwargs 1563 else: 1564 call_arg_spec = tf_inspect.getfullargspec(self.call) 1565 # There is no explicit "inputs" argument expected or provided to 1566 # call(). Arguments which have default values are considered non-inputs, 1567 # and arguments without are considered inputs. 1568 if call_arg_spec.defaults: 1569 if call_arg_spec.varargs is not None: 1570 raise TypeError( 1571 'Layers may not accept both positional arguments and ' 1572 'arguments with default values (unable to determine which ' 1573 'are inputs to the layer). ' 1574 'Issue occurred with layer "%s"' % (self.name)) 1575 keyword_arg_names = set( 1576 call_arg_spec.args[-len(call_arg_spec.defaults):]) 1577 else: 1578 keyword_arg_names = set() 1579 # Training is never an input argument name, to allow signatures like 1580 # call(x, training). 1581 keyword_arg_names.add('training') 1582 _, unwrapped_call = tf_decorator.unwrap(self.call) 1583 bound_args = inspect.getcallargs( 1584 unwrapped_call, *call_args, **call_kwargs) 1585 if call_arg_spec.varkw is not None: 1586 var_kwargs = bound_args.pop(call_arg_spec.varkw) 1587 bound_args.update(var_kwargs) 1588 keyword_arg_names = keyword_arg_names.union(var_kwargs.keys()) 1589 all_args = call_arg_spec.args 1590 if all_args and bound_args[all_args[0]] is self: 1591 # Ignore the 'self' argument of methods 1592 bound_args.pop(call_arg_spec.args[0]) 1593 all_args = all_args[1:] 1594 non_input_arg_values = {} 1595 input_arg_values = [] 1596 remaining_args_are_keyword = False 1597 for argument_name in all_args: 1598 if argument_name in keyword_arg_names: 1599 remaining_args_are_keyword = True 1600 else: 1601 if remaining_args_are_keyword: 1602 raise TypeError( 1603 'Found a positional argument in a layer call after a non-input ' 1604 'argument. All arguments after "training" must be keyword ' 1605 'arguments, and are not tracked as inputs to the layer. ' 1606 'Issue occurred with layer "%s"' % (self.name)) 1607 if remaining_args_are_keyword: 1608 non_input_arg_values[argument_name] = bound_args[argument_name] 1609 else: 1610 input_arg_values.append(bound_args[argument_name]) 1611 if call_arg_spec.varargs is not None: 1612 input_arg_values.extend(bound_args[call_arg_spec.varargs]) 1613 return input_arg_values, non_input_arg_values 1614 1615 def _add_inbound_node(self, 1616 input_tensors, 1617 output_tensors, 1618 arguments=None): 1619 """Internal method to create an inbound node for the layer. 1620 1621 Arguments: 1622 input_tensors: list of input tensors. 1623 output_tensors: list of output tensors. 1624 arguments: dictionary of keyword arguments that were passed to the 1625 `call` method of the layer at the call that created the node. 1626 """ 1627 inbound_layers = nest.map_structure(lambda t: t._keras_history[0], 1628 input_tensors) 1629 node_indices = nest.map_structure(lambda t: t._keras_history[1], 1630 input_tensors) 1631 tensor_indices = nest.map_structure(lambda t: t._keras_history[2], 1632 input_tensors) 1633 1634 # Create node, add it to inbound nodes. 1635 Node( 1636 self, 1637 inbound_layers=inbound_layers, 1638 node_indices=node_indices, 1639 tensor_indices=tensor_indices, 1640 input_tensors=input_tensors, 1641 output_tensors=output_tensors, 1642 arguments=arguments) 1643 1644 # Update tensor history metadata. 1645 # The metadata attribute consists of 1646 # 1) a layer instance 1647 # 2) a node index for the layer 1648 # 3) a tensor index for the node. 1649 # The allows layer reuse (multiple nodes per layer) and multi-output 1650 # or multi-input layers (e.g. a layer can return multiple tensors, 1651 # and each can be sent to a different layer). 1652 for i, tensor in enumerate(nest.flatten(output_tensors)): 1653 tensor._keras_history = (self, len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access 1654 1655 def _get_node_attribute_at_index(self, node_index, attr, attr_name): 1656 """Private utility to retrieves an attribute (e.g. inputs) from a node. 1657 1658 This is used to implement the methods: 1659 - get_input_shape_at 1660 - get_output_shape_at 1661 - get_input_at 1662 etc... 1663 1664 Arguments: 1665 node_index: Integer index of the node from which 1666 to retrieve the attribute. 1667 attr: Exact node attribute name. 1668 attr_name: Human-readable attribute name, for error messages. 1669 1670 Returns: 1671 The layer's attribute `attr` at the node of index `node_index`. 1672 1673 Raises: 1674 RuntimeError: If the layer has no inbound nodes, or if called in Eager 1675 mode. 1676 ValueError: If the index provided does not match any node. 1677 """ 1678 if not self._inbound_nodes: 1679 raise RuntimeError('The layer has never been called ' 1680 'and thus has no defined ' + attr_name + '.') 1681 if not len(self._inbound_nodes) > node_index: 1682 raise ValueError('Asked to get ' + attr_name + ' at node ' + 1683 str(node_index) + ', but the layer has only ' + 1684 str(len(self._inbound_nodes)) + ' inbound nodes.') 1685 values = getattr(self._inbound_nodes[node_index], attr) 1686 if isinstance(values, list) and len(values) == 1: 1687 return values[0] 1688 else: 1689 return values 1690 1691 def _maybe_build(self, inputs): 1692 # Check input assumptions set before layer building, e.g. input rank. 1693 if self.built: 1694 return 1695 1696 input_spec.assert_input_compatibility( 1697 self.input_spec, inputs, self.name) 1698 input_list = nest.flatten(inputs) 1699 if input_list and self._dtype is None: 1700 try: 1701 self._dtype = input_list[0].dtype.base_dtype.name 1702 except AttributeError: 1703 pass 1704 input_shapes = None 1705 if all(hasattr(x, 'shape') for x in input_list): 1706 input_shapes = nest.map_structure(lambda x: x.shape, inputs) 1707 # Only call `build` if the user has manually overridden the build method. 1708 if not hasattr(self.build, '_is_default'): 1709 self.build(input_shapes) 1710 # We must set self.built since user defined build functions are not 1711 # constrained to set self.built. 1712 self.built = True 1713 1714 def _symbolic_call(self, inputs): 1715 input_shapes = nest.map_structure(lambda x: x.shape, inputs) 1716 output_shapes = self.compute_output_shape(input_shapes) 1717 1718 def _make_placeholder_like(shape): 1719 ph = backend.placeholder(shape=shape, dtype=self.dtype) 1720 ph._keras_mask = None 1721 return ph 1722 1723 return nest.map_structure(_make_placeholder_like, output_shapes) 1724 1725 @property 1726 def _obj_reference_counts(self): 1727 """A dictionary counting the number of attributes referencing an object.""" 1728 if not hasattr(self, '_obj_reference_counts_dict'): 1729 super(Layer, self).__setattr__( 1730 '_obj_reference_counts_dict', 1731 object_identity.ObjectIdentityDictionary()) 1732 return self._obj_reference_counts_dict 1733 1734 def __delattr__(self, name): 1735 existing_value = getattr(self, name, None) 1736 1737 # If this value is replacing an existing object assigned to an attribute, we 1738 # should clean it out to avoid leaking memory. First we check if there are 1739 # other attributes referencing it. 1740 reference_counts = self._obj_reference_counts 1741 if existing_value not in reference_counts: 1742 super(Layer, self).__delattr__(name) 1743 return 1744 1745 reference_count = reference_counts[existing_value] 1746 if reference_count > 1: 1747 # There are other remaining references. We can't remove this object from 1748 # _layers etc. 1749 reference_counts[existing_value] = reference_count - 1 1750 super(Layer, self).__delattr__(name) 1751 return 1752 else: 1753 # This is the last remaining reference. 1754 del reference_counts[existing_value] 1755 1756 super(Layer, self).__delattr__(name) 1757 1758 if (isinstance(existing_value, Layer) 1759 or trackable_layer_utils.has_weights(existing_value)): 1760 super(Layer, self).__setattr__( 1761 '_layers', 1762 [l for l in self._layers if l is not existing_value]) 1763 if isinstance(existing_value, tf_variables.Variable): 1764 super(Layer, self).__setattr__( 1765 '_trainable_weights', 1766 [w for w in self._trainable_weights if w is not existing_value]) 1767 super(Layer, self).__setattr__( 1768 '_non_trainable_weights', 1769 [w for w in self._non_trainable_weights if w is not existing_value]) 1770 1771 def __setattr__(self, name, value): 1772 if (not getattr(self, '_setattr_tracking', True) or 1773 getattr(self, '_is_graph_network', False) or 1774 # Exclude @property.setters from tracking 1775 hasattr(self.__class__, name)): 1776 super(Layer, self).__setattr__(name, value) 1777 return 1778 1779 # Keep track of trackable objects, for the needs of `Network.save_weights`. 1780 value = data_structures.sticky_attribute_assignment( 1781 trackable=self, value=value, name=name) 1782 1783 reference_counts = self._obj_reference_counts 1784 reference_counts[value] = reference_counts.get(value, 0) + 1 1785 1786 # Clean out the old attribute, which clears _layers and _trainable_weights 1787 # if necessary. 1788 try: 1789 self.__delattr__(name) 1790 except AttributeError: 1791 pass 1792 1793 # Append value to self._layers if relevant 1794 if (isinstance(value, Layer) or 1795 trackable_layer_utils.has_weights(value)): 1796 # Initialize `_layers` here in case `__init__` has not yet been called. 1797 if not hasattr(self, '_layers'): 1798 super(Layer, self).__setattr__('_layers', []) 1799 # We need to check object identity to avoid de-duplicating empty 1800 # container types which compare equal. 1801 if not any((layer is value for layer in self._layers)): 1802 self._layers.append(value) 1803 if hasattr(value, '_use_resource_variables'): 1804 # Legacy layers (V1 tf.layers) must always use 1805 # resource variables. 1806 value._use_resource_variables = True 1807 1808 # Append value to list of trainable / non-trainable weights if relevant 1809 # TODO(b/125122625): This won't pick up on any variables added to a 1810 # list/dict after creation. 1811 for val in nest.flatten(value): 1812 # TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops 1813 # no longer return True for isinstance Variable checks. 1814 if (isinstance(val, tf_variables.Variable) and 1815 not isinstance(val, resource_variable_ops._UnreadVariable)): # pylint: disable=protected-access 1816 # Users may add extra weights/variables 1817 # simply by assigning them to attributes (invalid for graph networks) 1818 if not hasattr(self, '_trainable_weights'): 1819 super(Layer, self).__setattr__('_trainable_weights', []) 1820 if not hasattr(self, '_non_trainable_weights'): 1821 super(Layer, self).__setattr__('_non_trainable_weights', []) 1822 if val not in self._trainable_weights + self._non_trainable_weights: 1823 if val.trainable: 1824 self._trainable_weights.append(val) 1825 else: 1826 self._non_trainable_weights.append(val) 1827 backend.track_variable(val) 1828 1829 super(Layer, self).__setattr__(name, value) 1830 1831 def _gather_children_attribute(self, attribute): 1832 assert attribute in { 1833 'weights', 'trainable_weights', 'non_trainable_weights', 'updates', 1834 'losses' 1835 } 1836 if hasattr(self, '_layers'): 1837 nested_layers = trackable_layer_utils.filter_empty_layer_containers( 1838 self._layers) 1839 return list( 1840 itertools.chain.from_iterable( 1841 getattr(layer, attribute) for layer in nested_layers)) 1842 return [] 1843 1844 # This is a hack so that the is_layer (within 1845 # training/trackable/layer_utils.py) check doesn't get the weights attr. 1846 # TODO(b/110718070): Remove when fixed. 1847 def _is_layer(self): 1848 return True 1849 1850 def _get_unfiltered_updates(self, check_trainable=True): 1851 if check_trainable and not self.trainable and not self.stateful: 1852 return [] 1853 return self._updates + self._gather_children_attribute('updates') 1854 1855 1856class Node(object): 1857 """A `Node` describes the connectivity between two layers. 1858 1859 Each time a layer is connected to some new input, 1860 a node is added to `layer._inbound_nodes`. 1861 Each time the output of a layer is used by another layer, 1862 a node is added to `layer._outbound_nodes`. 1863 1864 Arguments: 1865 outbound_layer: the layer that takes 1866 `input_tensors` and turns them into `output_tensors` 1867 (the node gets created when the `call` 1868 method of the layer was called). 1869 inbound_layers: a list of layers, the same length as `input_tensors`, 1870 the layers from where `input_tensors` originate. 1871 node_indices: a list of integers, the same length as `inbound_layers`. 1872 `node_indices[i]` is the origin node of `input_tensors[i]` 1873 (necessary since each inbound layer might have several nodes, 1874 e.g. if the layer is being shared with a different data stream). 1875 tensor_indices: a list of integers, 1876 the same length as `inbound_layers`. 1877 `tensor_indices[i]` is the index of `input_tensors[i]` within the 1878 output of the inbound layer 1879 (necessary since each inbound layer might 1880 have multiple tensor outputs, with each one being 1881 independently manipulable). 1882 input_tensors: list of input tensors. 1883 output_tensors: list of output tensors. 1884 arguments: dictionary of keyword arguments that were passed to the 1885 `call` method of the layer at the call that created the node. 1886 1887 `node_indices` and `tensor_indices` are basically fine-grained coordinates 1888 describing the origin of the `input_tensors`. 1889 1890 A node from layer A to layer B is added to: 1891 - A._outbound_nodes 1892 - B._inbound_nodes 1893 """ 1894 1895 def __init__(self, 1896 outbound_layer, 1897 inbound_layers, 1898 node_indices, 1899 tensor_indices, 1900 input_tensors, 1901 output_tensors, 1902 arguments=None): 1903 # Layer instance (NOT a sequence) 1904 if isinstance(outbound_layer, (list, tuple, dict)): 1905 raise ValueError('`outbound_layer` should be a layer instance, ' 1906 'not a list, tuple, or, dict.') 1907 1908 # this is the layer that takes a nested structure of input tensors 1909 # and turns them into a nested structure of output tensors. 1910 # the current node will be added to 1911 # the inbound_nodes of outbound_layer. 1912 self.outbound_layer = outbound_layer 1913 1914 # The following 3 properties describe where 1915 # the input tensors come from: which layers, 1916 # and for each layer, which node and which 1917 # tensor output of each node. 1918 1919 # Nested structure of layer instances. 1920 self.inbound_layers = inbound_layers 1921 # Nested structure of integers, 1:1 mapping with inbound_layers. 1922 self.node_indices = node_indices 1923 # Nested of integers, 1:1 mapping with inbound_layers. 1924 self.tensor_indices = tensor_indices 1925 1926 # Following 2 properties: 1927 # tensor inputs and outputs of outbound_layer. 1928 1929 # Nested structure of tensors. 1:1 mapping with inbound_layers. 1930 self.input_tensors = input_tensors 1931 # Nested structure of tensors, created by outbound_layer.call(). 1932 self.output_tensors = output_tensors 1933 1934 # Following 2 properties: input and output shapes. 1935 1936 # Nested structure of shape tuples, shapes of input_tensors. 1937 self.input_shapes = nest.map_structure(backend.int_shape, input_tensors) 1938 # Nested structure of shape tuples, shapes of output_tensors. 1939 self.output_shapes = nest.map_structure(backend.int_shape, output_tensors) 1940 1941 # Optional keyword arguments to layer's `call`. 1942 self.arguments = arguments 1943 1944 # Add nodes to all layers involved. 1945 for layer in nest.flatten(inbound_layers): 1946 if layer is not None: 1947 # For compatibility with external Keras, we use the deprecated 1948 # accessor here. 1949 layer.outbound_nodes.append(self) 1950 # For compatibility with external Keras, we use the deprecated 1951 # accessor here. 1952 outbound_layer.inbound_nodes.append(self) 1953 1954 def iterate_inbound(self): 1955 """Returns a list of tuples representing the inbound data. 1956 1957 Returns: 1958 List of tuples like: (inbound_layer, node_index, tensor_index, tensor). 1959 """ 1960 return zip( 1961 nest.flatten(self.inbound_layers), nest.flatten(self.node_indices), 1962 nest.flatten(self.tensor_indices), nest.flatten(self.input_tensors)) 1963 1964 def get_config(self): 1965 inbound_names = nest.map_structure( 1966 lambda layer: layer.name if layer else None, self.inbound_layers) 1967 return { 1968 'outbound_layer': self.outbound_layer.name, 1969 'inbound_layers': inbound_names, 1970 'node_indices': self.node_indices, 1971 'tensor_indices': self.tensor_indices 1972 } 1973 1974 1975class TensorFlowOpLayer(Layer): 1976 """Wraps a TensorFlow Operation in a Layer. 1977 1978 This class is used internally by the Functional API. When a user 1979 uses a raw TensorFlow Operation on symbolic tensors originating 1980 from an `Input` Layer, the resultant operation will be wrapped 1981 with this Layer object in order to make the operation compatible 1982 with the Keras API. 1983 1984 This Layer will create a new, identical operation (except for inputs 1985 and outputs) every time it is called. If `run_eagerly` is `True`, 1986 the op creation and calculation will happen inside an Eager function. 1987 1988 Instances of this Layer are created when `autolambda` is called, which 1989 is whenever a Layer's `__call__` encounters symbolic inputs that do 1990 not have Keras metadata, or when a Network's `__init__` encounters 1991 outputs that do not have Keras metadata. 1992 1993 Attributes: 1994 node_def: String, the serialized NodeDef of the Op this layer will wrap. 1995 constants: Dict of NumPy arrays, the values of any Tensors needed for this 1996 Operation that do not originate from a Keras `Input` Layer. Since all 1997 placeholders must come from Keras `Input` Layers, these Tensors must be 1998 treated as constant in the Functional API. 1999 name: String, the name of the Layer. 2000 trainable: Bool, whether this Layer is trainable. Currently Variables are 2001 not supported, and so this parameter has no effect. 2002 dtype: The default dtype of this Layer. Inherited from `Layer` and has no 2003 effect on this class, however is used in `get_config`. 2004 """ 2005 2006 def __init__(self, 2007 node_def, 2008 constants=None, 2009 name=None, 2010 trainable=True, 2011 dtype=None): 2012 super(TensorFlowOpLayer, self).__init__( 2013 name=name, trainable=trainable, dtype=dtype) 2014 self.node_def = node_def_pb2.NodeDef.FromString(node_def) 2015 self.constants = constants or {} 2016 # Layer uses original op unless it is called on new inputs. 2017 # This means `built` is not set in `__call__`. 2018 self.built = True 2019 2020 def call(self, inputs): 2021 if context.executing_eagerly(): 2022 return self._defun_call(inputs) 2023 return self._make_op(inputs) 2024 2025 def _make_op(self, inputs): 2026 inputs = nest.flatten(inputs) 2027 graph = inputs[0].graph 2028 with graph.as_default(): 2029 for index, constant in self.constants.items(): 2030 constant = ops.convert_to_tensor(constant) 2031 inputs.insert(index, constant) 2032 2033 self.node_def.name = graph.unique_name(self.node_def.name) 2034 # Check for case where first input should be a list of Tensors. 2035 if 'N' in self.node_def.attr: 2036 num_tensors = self.node_def.attr['N'].i 2037 inputs = [inputs[:num_tensors]] + inputs[num_tensors:] 2038 c_op = ops._create_c_op(graph, self.node_def, inputs, control_inputs=[]) 2039 op = graph._create_op_from_tf_operation(c_op) 2040 2041 if len(op.outputs) == 1: 2042 return op.outputs[0] 2043 return op.outputs 2044 2045 @function.defun 2046 def _defun_call(self, inputs): 2047 """Wraps the op creation method in an Eager function for `run_eagerly`.""" 2048 return self._make_op(inputs) 2049 2050 def get_config(self): 2051 config = super(TensorFlowOpLayer, self).get_config() 2052 config.update({ 2053 'node_def': self.node_def.SerializeToString(), 2054 'constants': self.constants 2055 }) 2056 return config 2057 2058 2059def default(method): 2060 """Decorates a method to detect overrides in subclasses.""" 2061 method._is_default = True 2062 return method 2063 2064 2065# Avoid breaking users who directly import this symbol from this file. 2066# TODO(fchollet): remove this. 2067InputSpec = input_spec.InputSpec # pylint:disable=invalid-name 2068