1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Core Keras layers. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import functools 23import operator 24import sys 25import textwrap 26import types as python_types 27import warnings 28 29import numpy as np 30 31from tensorflow.python.eager import backprop 32from tensorflow.python.eager import context 33from tensorflow.python.eager import monitoring 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import tensor_shape 38from tensorflow.python.keras import activations 39from tensorflow.python.keras import backend as K 40from tensorflow.python.keras import constraints 41from tensorflow.python.keras import initializers 42from tensorflow.python.keras import regularizers 43from tensorflow.python.keras.engine import keras_tensor 44from tensorflow.python.keras.engine.base_layer import Layer 45from tensorflow.python.keras.engine.input_spec import InputSpec 46from tensorflow.python.keras.layers.ops import core as core_ops 47from tensorflow.python.keras.utils import control_flow_util 48from tensorflow.python.keras.utils import conv_utils 49from tensorflow.python.keras.utils import generic_utils 50from tensorflow.python.keras.utils import tf_inspect 51from tensorflow.python.keras.utils import tf_utils 52from tensorflow.python.ops import array_ops 53from tensorflow.python.ops import math_ops 54from tensorflow.python.ops import nn 55from tensorflow.python.ops import variable_scope 56from tensorflow.python.ops.ragged import ragged_tensor 57from tensorflow.python.platform import tf_logging 58from tensorflow.python.training.tracking import base as trackable 59from tensorflow.python.util import dispatch 60from tensorflow.python.util import nest 61from tensorflow.python.util import tf_decorator 62from tensorflow.python.util.tf_export import get_canonical_name_for_symbol 63from tensorflow.python.util.tf_export import get_symbol_from_name 64from tensorflow.python.util.tf_export import keras_export 65 66# TODO(b/168039935): track dropout rate to decide whether/how to make a 67# dropout rate fastpath. 68keras_temporary_dropout_rate = monitoring.BoolGauge( 69 '/tensorflow/api/keras/dropout/temp_rate_is_zero', 70 'Temporarily record if Keras dropout layer was created w/' 71 'constant rate = 0') 72 73 74# pylint: disable=g-classes-have-attributes 75@keras_export('keras.layers.Masking') 76class Masking(Layer): 77 """Masks a sequence by using a mask value to skip timesteps. 78 79 For each timestep in the input tensor (dimension #1 in the tensor), 80 if all values in the input tensor at that timestep 81 are equal to `mask_value`, then the timestep will be masked (skipped) 82 in all downstream layers (as long as they support masking). 83 84 If any downstream layer does not support masking yet receives such 85 an input mask, an exception will be raised. 86 87 Example: 88 89 Consider a Numpy data array `x` of shape `(samples, timesteps, features)`, 90 to be fed to an LSTM layer. You want to mask timestep #3 and #5 because you 91 lack data for these timesteps. You can: 92 93 - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.` 94 - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer: 95 96 ```python 97 samples, timesteps, features = 32, 10, 8 98 inputs = np.random.random([samples, timesteps, features]).astype(np.float32) 99 inputs[:, 3, :] = 0. 100 inputs[:, 5, :] = 0. 101 102 model = tf.keras.models.Sequential() 103 model.add(tf.keras.layers.Masking(mask_value=0., 104 input_shape=(timesteps, features))) 105 model.add(tf.keras.layers.LSTM(32)) 106 107 output = model(inputs) 108 # The time step 3 and 5 will be skipped from LSTM calculation. 109 ``` 110 111 See [the masking and padding guide]( 112 https://www.tensorflow.org/guide/keras/masking_and_padding) 113 for more details. 114 """ 115 116 def __init__(self, mask_value=0., **kwargs): 117 super(Masking, self).__init__(**kwargs) 118 self.supports_masking = True 119 self.mask_value = mask_value 120 self._compute_output_and_mask_jointly = True 121 122 def compute_mask(self, inputs, mask=None): 123 return K.any(math_ops.not_equal(inputs, self.mask_value), axis=-1) 124 125 def call(self, inputs): 126 boolean_mask = K.any( 127 math_ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True) 128 outputs = inputs * math_ops.cast(boolean_mask, inputs.dtype) 129 # Compute the mask and outputs simultaneously. 130 outputs._keras_mask = array_ops.squeeze(boolean_mask, axis=-1) # pylint: disable=protected-access 131 return outputs 132 133 def compute_output_shape(self, input_shape): 134 return input_shape 135 136 def get_config(self): 137 config = {'mask_value': self.mask_value} 138 base_config = super(Masking, self).get_config() 139 return dict(list(base_config.items()) + list(config.items())) 140 141 142@keras_export('keras.layers.Dropout') 143class Dropout(Layer): 144 """Applies Dropout to the input. 145 146 The Dropout layer randomly sets input units to 0 with a frequency of `rate` 147 at each step during training time, which helps prevent overfitting. 148 Inputs not set to 0 are scaled up by 1/(1 - rate) such that the sum over 149 all inputs is unchanged. 150 151 Note that the Dropout layer only applies when `training` is set to True 152 such that no values are dropped during inference. When using `model.fit`, 153 `training` will be appropriately set to True automatically, and in other 154 contexts, you can set the kwarg explicitly to True when calling the layer. 155 156 (This is in contrast to setting `trainable=False` for a Dropout layer. 157 `trainable` does not affect the layer's behavior, as Dropout does 158 not have any variables/weights that can be frozen during training.) 159 160 >>> tf.random.set_seed(0) 161 >>> layer = tf.keras.layers.Dropout(.2, input_shape=(2,)) 162 >>> data = np.arange(10).reshape(5, 2).astype(np.float32) 163 >>> print(data) 164 [[0. 1.] 165 [2. 3.] 166 [4. 5.] 167 [6. 7.] 168 [8. 9.]] 169 >>> outputs = layer(data, training=True) 170 >>> print(outputs) 171 tf.Tensor( 172 [[ 0. 1.25] 173 [ 2.5 3.75] 174 [ 5. 6.25] 175 [ 7.5 8.75] 176 [10. 0. ]], shape=(5, 2), dtype=float32) 177 178 Args: 179 rate: Float between 0 and 1. Fraction of the input units to drop. 180 noise_shape: 1D integer tensor representing the shape of the 181 binary dropout mask that will be multiplied with the input. 182 For instance, if your inputs have shape 183 `(batch_size, timesteps, features)` and 184 you want the dropout mask to be the same for all timesteps, 185 you can use `noise_shape=(batch_size, 1, features)`. 186 seed: A Python integer to use as random seed. 187 188 Call arguments: 189 inputs: Input tensor (of any rank). 190 training: Python boolean indicating whether the layer should behave in 191 training mode (adding dropout) or in inference mode (doing nothing). 192 """ 193 194 def __init__(self, rate, noise_shape=None, seed=None, **kwargs): 195 super(Dropout, self).__init__(**kwargs) 196 self.rate = rate 197 if isinstance(rate, (int, float)) and not rate: 198 keras_temporary_dropout_rate.get_cell().set(True) 199 else: 200 keras_temporary_dropout_rate.get_cell().set(False) 201 self.noise_shape = noise_shape 202 self.seed = seed 203 self.supports_masking = True 204 205 def _get_noise_shape(self, inputs): 206 # Subclasses of `Dropout` may implement `_get_noise_shape(self, inputs)`, 207 # which will override `self.noise_shape`, and allows for custom noise 208 # shapes with dynamically sized inputs. 209 if self.noise_shape is None: 210 return None 211 212 concrete_inputs_shape = array_ops.shape(inputs) 213 noise_shape = [] 214 for i, value in enumerate(self.noise_shape): 215 noise_shape.append(concrete_inputs_shape[i] if value is None else value) 216 return ops.convert_to_tensor_v2_with_dispatch(noise_shape) 217 218 def call(self, inputs, training=None): 219 if training is None: 220 training = K.learning_phase() 221 222 def dropped_inputs(): 223 return nn.dropout( 224 inputs, 225 noise_shape=self._get_noise_shape(inputs), 226 seed=self.seed, 227 rate=self.rate) 228 229 output = control_flow_util.smart_cond(training, dropped_inputs, 230 lambda: array_ops.identity(inputs)) 231 return output 232 233 def compute_output_shape(self, input_shape): 234 return input_shape 235 236 def get_config(self): 237 config = { 238 'rate': self.rate, 239 'noise_shape': self.noise_shape, 240 'seed': self.seed 241 } 242 base_config = super(Dropout, self).get_config() 243 return dict(list(base_config.items()) + list(config.items())) 244 245 246@keras_export('keras.layers.SpatialDropout1D') 247class SpatialDropout1D(Dropout): 248 """Spatial 1D version of Dropout. 249 250 This version performs the same function as Dropout, however, it drops 251 entire 1D feature maps instead of individual elements. If adjacent frames 252 within feature maps are strongly correlated (as is normally the case in 253 early convolution layers) then regular dropout will not regularize the 254 activations and will otherwise just result in an effective learning rate 255 decrease. In this case, SpatialDropout1D will help promote independence 256 between feature maps and should be used instead. 257 258 Args: 259 rate: Float between 0 and 1. Fraction of the input units to drop. 260 261 Call arguments: 262 inputs: A 3D tensor. 263 training: Python boolean indicating whether the layer should behave in 264 training mode (adding dropout) or in inference mode (doing nothing). 265 266 Input shape: 267 3D tensor with shape: 268 `(samples, timesteps, channels)` 269 270 Output shape: 271 Same as input. 272 273 References: 274 - [Efficient Object Localization Using Convolutional 275 Networks](https://arxiv.org/abs/1411.4280) 276 """ 277 278 def __init__(self, rate, **kwargs): 279 super(SpatialDropout1D, self).__init__(rate, **kwargs) 280 self.input_spec = InputSpec(ndim=3) 281 282 def _get_noise_shape(self, inputs): 283 input_shape = array_ops.shape(inputs) 284 noise_shape = (input_shape[0], 1, input_shape[2]) 285 return noise_shape 286 287 288@keras_export('keras.layers.SpatialDropout2D') 289class SpatialDropout2D(Dropout): 290 """Spatial 2D version of Dropout. 291 292 This version performs the same function as Dropout, however, it drops 293 entire 2D feature maps instead of individual elements. If adjacent pixels 294 within feature maps are strongly correlated (as is normally the case in 295 early convolution layers) then regular dropout will not regularize the 296 activations and will otherwise just result in an effective learning rate 297 decrease. In this case, SpatialDropout2D will help promote independence 298 between feature maps and should be used instead. 299 300 Args: 301 rate: Float between 0 and 1. Fraction of the input units to drop. 302 data_format: 'channels_first' or 'channels_last'. 303 In 'channels_first' mode, the channels dimension 304 (the depth) is at index 1, 305 in 'channels_last' mode is it at index 3. 306 It defaults to the `image_data_format` value found in your 307 Keras config file at `~/.keras/keras.json`. 308 If you never set it, then it will be "channels_last". 309 310 Call arguments: 311 inputs: A 4D tensor. 312 training: Python boolean indicating whether the layer should behave in 313 training mode (adding dropout) or in inference mode (doing nothing). 314 315 Input shape: 316 4D tensor with shape: 317 `(samples, channels, rows, cols)` if data_format='channels_first' 318 or 4D tensor with shape: 319 `(samples, rows, cols, channels)` if data_format='channels_last'. 320 321 Output shape: 322 Same as input. 323 324 References: 325 - [Efficient Object Localization Using Convolutional 326 Networks](https://arxiv.org/abs/1411.4280) 327 """ 328 329 def __init__(self, rate, data_format=None, **kwargs): 330 super(SpatialDropout2D, self).__init__(rate, **kwargs) 331 if data_format is None: 332 data_format = K.image_data_format() 333 if data_format not in {'channels_last', 'channels_first'}: 334 raise ValueError('data_format must be in ' 335 '{"channels_last", "channels_first"}') 336 self.data_format = data_format 337 self.input_spec = InputSpec(ndim=4) 338 339 def _get_noise_shape(self, inputs): 340 input_shape = array_ops.shape(inputs) 341 if self.data_format == 'channels_first': 342 return (input_shape[0], input_shape[1], 1, 1) 343 elif self.data_format == 'channels_last': 344 return (input_shape[0], 1, 1, input_shape[3]) 345 346 347@keras_export('keras.layers.SpatialDropout3D') 348class SpatialDropout3D(Dropout): 349 """Spatial 3D version of Dropout. 350 351 This version performs the same function as Dropout, however, it drops 352 entire 3D feature maps instead of individual elements. If adjacent voxels 353 within feature maps are strongly correlated (as is normally the case in 354 early convolution layers) then regular dropout will not regularize the 355 activations and will otherwise just result in an effective learning rate 356 decrease. In this case, SpatialDropout3D will help promote independence 357 between feature maps and should be used instead. 358 359 Args: 360 rate: Float between 0 and 1. Fraction of the input units to drop. 361 data_format: 'channels_first' or 'channels_last'. 362 In 'channels_first' mode, the channels dimension (the depth) 363 is at index 1, in 'channels_last' mode is it at index 4. 364 It defaults to the `image_data_format` value found in your 365 Keras config file at `~/.keras/keras.json`. 366 If you never set it, then it will be "channels_last". 367 368 Call arguments: 369 inputs: A 5D tensor. 370 training: Python boolean indicating whether the layer should behave in 371 training mode (adding dropout) or in inference mode (doing nothing). 372 373 Input shape: 374 5D tensor with shape: 375 `(samples, channels, dim1, dim2, dim3)` if data_format='channels_first' 376 or 5D tensor with shape: 377 `(samples, dim1, dim2, dim3, channels)` if data_format='channels_last'. 378 379 Output shape: 380 Same as input. 381 382 References: 383 - [Efficient Object Localization Using Convolutional 384 Networks](https://arxiv.org/abs/1411.4280) 385 """ 386 387 def __init__(self, rate, data_format=None, **kwargs): 388 super(SpatialDropout3D, self).__init__(rate, **kwargs) 389 if data_format is None: 390 data_format = K.image_data_format() 391 if data_format not in {'channels_last', 'channels_first'}: 392 raise ValueError('data_format must be in ' 393 '{"channels_last", "channels_first"}') 394 self.data_format = data_format 395 self.input_spec = InputSpec(ndim=5) 396 397 def _get_noise_shape(self, inputs): 398 input_shape = array_ops.shape(inputs) 399 if self.data_format == 'channels_first': 400 return (input_shape[0], input_shape[1], 1, 1, 1) 401 elif self.data_format == 'channels_last': 402 return (input_shape[0], 1, 1, 1, input_shape[4]) 403 404 405@keras_export('keras.layers.Activation') 406class Activation(Layer): 407 """Applies an activation function to an output. 408 409 Args: 410 activation: Activation function, such as `tf.nn.relu`, or string name of 411 built-in activation function, such as "relu". 412 413 Usage: 414 415 >>> layer = tf.keras.layers.Activation('relu') 416 >>> output = layer([-3.0, -1.0, 0.0, 2.0]) 417 >>> list(output.numpy()) 418 [0.0, 0.0, 0.0, 2.0] 419 >>> layer = tf.keras.layers.Activation(tf.nn.relu) 420 >>> output = layer([-3.0, -1.0, 0.0, 2.0]) 421 >>> list(output.numpy()) 422 [0.0, 0.0, 0.0, 2.0] 423 424 Input shape: 425 Arbitrary. Use the keyword argument `input_shape` 426 (tuple of integers, does not include the batch axis) 427 when using this layer as the first layer in a model. 428 429 Output shape: 430 Same shape as input. 431 """ 432 433 def __init__(self, activation, **kwargs): 434 super(Activation, self).__init__(**kwargs) 435 self.supports_masking = True 436 self.activation = activations.get(activation) 437 438 def call(self, inputs): 439 return self.activation(inputs) 440 441 def compute_output_shape(self, input_shape): 442 return input_shape 443 444 def get_config(self): 445 config = {'activation': activations.serialize(self.activation)} 446 base_config = super(Activation, self).get_config() 447 return dict(list(base_config.items()) + list(config.items())) 448 449 450@keras_export('keras.layers.Reshape') 451class Reshape(Layer): 452 """Layer that reshapes inputs into the given shape. 453 454 Input shape: 455 Arbitrary, although all dimensions in the input shape must be known/fixed. 456 Use the keyword argument `input_shape` (tuple of integers, does not include 457 the samples/batch size axis) when using this layer as the first layer 458 in a model. 459 460 Output shape: 461 `(batch_size,) + target_shape` 462 463 Example: 464 465 >>> # as first layer in a Sequential model 466 >>> model = tf.keras.Sequential() 467 >>> model.add(tf.keras.layers.Reshape((3, 4), input_shape=(12,))) 468 >>> # model.output_shape == (None, 3, 4), `None` is the batch size. 469 >>> model.output_shape 470 (None, 3, 4) 471 472 >>> # as intermediate layer in a Sequential model 473 >>> model.add(tf.keras.layers.Reshape((6, 2))) 474 >>> model.output_shape 475 (None, 6, 2) 476 477 >>> # also supports shape inference using `-1` as dimension 478 >>> model.add(tf.keras.layers.Reshape((-1, 2, 2))) 479 >>> model.output_shape 480 (None, 3, 2, 2) 481 """ 482 483 def __init__(self, target_shape, **kwargs): 484 """Creates a `tf.keras.layers.Reshape` layer instance. 485 486 Args: 487 target_shape: Target shape. Tuple of integers, does not include the 488 samples dimension (batch size). 489 **kwargs: Any additional layer keyword arguments. 490 """ 491 super(Reshape, self).__init__(**kwargs) 492 self.target_shape = tuple(target_shape) 493 494 def _fix_unknown_dimension(self, input_shape, output_shape): 495 """Find and replace a missing dimension in an output shape. 496 497 This is a near direct port of the internal Numpy function 498 `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c` 499 500 Args: 501 input_shape: Shape of array being reshaped 502 output_shape: Desired shape of the array with at most 503 a single -1 which indicates a dimension that should be 504 derived from the input shape. 505 506 Returns: 507 The new output shape with a -1 replaced with its computed value. 508 509 Raises: 510 ValueError: If the total array size of the output_shape is 511 different than the input_shape, or more than one unknown dimension 512 is specified. 513 """ 514 output_shape = list(output_shape) 515 msg = ('total size of new array must be unchanged, ' 516 'input_shape = {}, output_shape = {}' 517 .format(input_shape, output_shape)) 518 519 known, unknown = 1, None 520 for index, dim in enumerate(output_shape): 521 if dim < 0: 522 if unknown is None: 523 unknown = index 524 else: 525 raise ValueError('Can only specify one unknown dimension.') 526 else: 527 known *= dim 528 529 original = np.prod(input_shape, dtype=int) 530 if unknown is not None: 531 if known == 0 or original % known != 0: 532 raise ValueError(msg) 533 output_shape[unknown] = original // known 534 elif original != known: 535 raise ValueError(msg) 536 return output_shape 537 538 def compute_output_shape(self, input_shape): 539 input_shape = tensor_shape.TensorShape(input_shape).as_list() 540 if None in input_shape[1:]: 541 output_shape = [input_shape[0]] 542 # input shape (partially) unknown? replace -1's with None's 543 output_shape += tuple(s if s != -1 else None for s in self.target_shape) 544 else: 545 output_shape = [input_shape[0]] 546 output_shape += self._fix_unknown_dimension(input_shape[1:], 547 self.target_shape) 548 return tensor_shape.TensorShape(output_shape) 549 550 def call(self, inputs): 551 result = array_ops.reshape( 552 inputs, (array_ops.shape(inputs)[0],) + self.target_shape) 553 if not context.executing_eagerly(): 554 # Set the static shape for the result since it might lost during array_ops 555 # reshape, eg, some `None` dim in the result could be inferred. 556 result.set_shape(self.compute_output_shape(inputs.shape)) 557 return result 558 559 def get_config(self): 560 config = {'target_shape': self.target_shape} 561 base_config = super(Reshape, self).get_config() 562 return dict(list(base_config.items()) + list(config.items())) 563 564 565@keras_export('keras.layers.Permute') 566class Permute(Layer): 567 """Permutes the dimensions of the input according to a given pattern. 568 569 Useful e.g. connecting RNNs and convnets. 570 571 Example: 572 573 ```python 574 model = Sequential() 575 model.add(Permute((2, 1), input_shape=(10, 64))) 576 # now: model.output_shape == (None, 64, 10) 577 # note: `None` is the batch dimension 578 ``` 579 580 Args: 581 dims: Tuple of integers. Permutation pattern does not include the 582 samples dimension. Indexing starts at 1. 583 For instance, `(2, 1)` permutes the first and second dimensions 584 of the input. 585 586 Input shape: 587 Arbitrary. Use the keyword argument `input_shape` 588 (tuple of integers, does not include the samples axis) 589 when using this layer as the first layer in a model. 590 591 Output shape: 592 Same as the input shape, but with the dimensions re-ordered according 593 to the specified pattern. 594 """ 595 596 def __init__(self, dims, **kwargs): 597 super(Permute, self).__init__(**kwargs) 598 self.dims = tuple(dims) 599 if sorted(dims) != list(range(1, len(dims) + 1)): 600 raise ValueError( 601 'Invalid permutation `dims` for Permute Layer: %s. ' 602 'The set of indices in `dims` must be consecutive and start from 1.' % 603 (dims,)) 604 self.input_spec = InputSpec(ndim=len(self.dims) + 1) 605 606 def compute_output_shape(self, input_shape): 607 input_shape = tensor_shape.TensorShape(input_shape).as_list() 608 output_shape = copy.copy(input_shape) 609 for i, dim in enumerate(self.dims): 610 target_dim = input_shape[dim] 611 output_shape[i + 1] = target_dim 612 return tensor_shape.TensorShape(output_shape) 613 614 def call(self, inputs): 615 return array_ops.transpose(inputs, perm=(0,) + self.dims) 616 617 def get_config(self): 618 config = {'dims': self.dims} 619 base_config = super(Permute, self).get_config() 620 return dict(list(base_config.items()) + list(config.items())) 621 622 623@keras_export('keras.layers.Flatten') 624class Flatten(Layer): 625 """Flattens the input. Does not affect the batch size. 626 627 Note: If inputs are shaped `(batch,)` without a feature axis, then 628 flattening adds an extra channel dimension and output shape is `(batch, 1)`. 629 630 Args: 631 data_format: A string, 632 one of `channels_last` (default) or `channels_first`. 633 The ordering of the dimensions in the inputs. 634 `channels_last` corresponds to inputs with shape 635 `(batch, ..., channels)` while `channels_first` corresponds to 636 inputs with shape `(batch, channels, ...)`. 637 It defaults to the `image_data_format` value found in your 638 Keras config file at `~/.keras/keras.json`. 639 If you never set it, then it will be "channels_last". 640 641 Example: 642 643 >>> model = tf.keras.Sequential() 644 >>> model.add(tf.keras.layers.Conv2D(64, 3, 3, input_shape=(3, 32, 32))) 645 >>> model.output_shape 646 (None, 1, 10, 64) 647 648 >>> model.add(Flatten()) 649 >>> model.output_shape 650 (None, 640) 651 652 """ 653 654 def __init__(self, data_format=None, **kwargs): 655 super(Flatten, self).__init__(**kwargs) 656 self.data_format = conv_utils.normalize_data_format(data_format) 657 self.input_spec = InputSpec(min_ndim=1) 658 self._channels_first = self.data_format == 'channels_first' 659 660 def call(self, inputs): 661 if self._channels_first: 662 rank = inputs.shape.rank 663 if rank and rank > 1: 664 # Switch to channels-last format. 665 permutation = [0] 666 permutation.extend(range(2, rank)) 667 permutation.append(1) 668 inputs = array_ops.transpose(inputs, perm=permutation) 669 670 if context.executing_eagerly(): 671 # Full static shape is guaranteed to be available. 672 # Performance: Using `constant_op` is much faster than passing a list. 673 flattened_shape = constant_op.constant([inputs.shape[0], -1]) 674 return array_ops.reshape(inputs, flattened_shape) 675 else: 676 input_shape = inputs.shape 677 rank = input_shape.rank 678 if rank == 1: 679 return array_ops.expand_dims_v2(inputs, axis=1) 680 else: 681 batch_dim = tensor_shape.dimension_value(input_shape[0]) 682 non_batch_dims = input_shape[1:] 683 # Reshape in a way that preserves as much shape info as possible. 684 if non_batch_dims.is_fully_defined(): 685 last_dim = int(functools.reduce(operator.mul, non_batch_dims)) 686 flattened_shape = constant_op.constant([-1, last_dim]) 687 elif batch_dim is not None: 688 flattened_shape = constant_op.constant([int(batch_dim), -1]) 689 else: 690 flattened_shape = [array_ops.shape_v2(inputs)[0], -1] 691 return array_ops.reshape(inputs, flattened_shape) 692 693 def compute_output_shape(self, input_shape): 694 input_shape = tensor_shape.TensorShape(input_shape).as_list() 695 if not input_shape: 696 output_shape = tensor_shape.TensorShape([1]) 697 else: 698 output_shape = [input_shape[0]] 699 if np.all(input_shape[1:]): 700 output_shape += [np.prod(input_shape[1:], dtype=int)] 701 else: 702 output_shape += [None] 703 return tensor_shape.TensorShape(output_shape) 704 705 def get_config(self): 706 config = super(Flatten, self).get_config() 707 config.update({'data_format': self.data_format}) 708 return config 709 710 711@keras_export('keras.layers.RepeatVector') 712class RepeatVector(Layer): 713 """Repeats the input n times. 714 715 Example: 716 717 ```python 718 model = Sequential() 719 model.add(Dense(32, input_dim=32)) 720 # now: model.output_shape == (None, 32) 721 # note: `None` is the batch dimension 722 723 model.add(RepeatVector(3)) 724 # now: model.output_shape == (None, 3, 32) 725 ``` 726 727 Args: 728 n: Integer, repetition factor. 729 730 Input shape: 731 2D tensor of shape `(num_samples, features)`. 732 733 Output shape: 734 3D tensor of shape `(num_samples, n, features)`. 735 """ 736 737 def __init__(self, n, **kwargs): 738 super(RepeatVector, self).__init__(**kwargs) 739 self.n = n 740 self.input_spec = InputSpec(ndim=2) 741 742 def compute_output_shape(self, input_shape): 743 input_shape = tensor_shape.TensorShape(input_shape).as_list() 744 return tensor_shape.TensorShape([input_shape[0], self.n, input_shape[1]]) 745 746 def call(self, inputs): 747 return K.repeat(inputs, self.n) 748 749 def get_config(self): 750 config = {'n': self.n} 751 base_config = super(RepeatVector, self).get_config() 752 return dict(list(base_config.items()) + list(config.items())) 753 754 755@keras_export('keras.layers.Lambda') 756class Lambda(Layer): 757 """Wraps arbitrary expressions as a `Layer` object. 758 759 The `Lambda` layer exists so that arbitrary expressions can be used 760 as a `Layer` when constructing `Sequential` 761 and Functional API models. `Lambda` layers are best suited for simple 762 operations or quick experimentation. For more advanced use cases, follow 763 [this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models) 764 for subclassing `tf.keras.layers.Layer`. 765 766 WARNING: `tf.keras.layers.Lambda` layers have (de)serialization limitations! 767 768 The main reason to subclass `tf.keras.layers.Layer` instead of using a 769 `Lambda` layer is saving and inspecting a Model. `Lambda` layers 770 are saved by serializing the Python bytecode, which is fundamentally 771 non-portable. They should only be loaded in the same environment where 772 they were saved. Subclassed layers can be saved in a more portable way 773 by overriding their `get_config` method. Models that rely on 774 subclassed Layers are also often easier to visualize and reason about. 775 776 Examples: 777 778 ```python 779 # add a x -> x^2 layer 780 model.add(Lambda(lambda x: x ** 2)) 781 ``` 782 ```python 783 # add a layer that returns the concatenation 784 # of the positive part of the input and 785 # the opposite of the negative part 786 787 def antirectifier(x): 788 x -= K.mean(x, axis=1, keepdims=True) 789 x = K.l2_normalize(x, axis=1) 790 pos = K.relu(x) 791 neg = K.relu(-x) 792 return K.concatenate([pos, neg], axis=1) 793 794 model.add(Lambda(antirectifier)) 795 ``` 796 797 Variables: 798 While it is possible to use Variables with Lambda layers, this practice is 799 discouraged as it can easily lead to bugs. For instance, consider the 800 following layer: 801 802 ```python 803 scale = tf.Variable(1.) 804 scale_layer = tf.keras.layers.Lambda(lambda x: x * scale) 805 ``` 806 807 Because scale_layer does not directly track the `scale` variable, it will 808 not appear in `scale_layer.trainable_weights` and will therefore not be 809 trained if `scale_layer` is used in a Model. 810 811 A better pattern is to write a subclassed Layer: 812 813 ```python 814 class ScaleLayer(tf.keras.layers.Layer): 815 def __init__(self): 816 super(ScaleLayer, self).__init__() 817 self.scale = tf.Variable(1.) 818 819 def call(self, inputs): 820 return inputs * self.scale 821 ``` 822 823 In general, Lambda layers can be convenient for simple stateless 824 computation, but anything more complex should use a subclass Layer instead. 825 826 Args: 827 function: The function to be evaluated. Takes input tensor as first 828 argument. 829 output_shape: Expected output shape from function. This argument can be 830 inferred if not explicitly provided. Can be a tuple or function. If a 831 tuple, it only specifies the first dimension onward; 832 sample dimension is assumed either the same as the input: `output_shape = 833 (input_shape[0], ) + output_shape` or, the input is `None` and 834 the sample dimension is also `None`: `output_shape = (None, ) + 835 output_shape` If a function, it specifies the entire shape as a function 836 of the 837 input shape: `output_shape = f(input_shape)` 838 mask: Either None (indicating no masking) or a callable with the same 839 signature as the `compute_mask` layer method, or a tensor that will be 840 returned as output mask regardless of what the input is. 841 arguments: Optional dictionary of keyword arguments to be passed to the 842 function. 843 844 Input shape: 845 Arbitrary. Use the keyword argument input_shape (tuple of 846 integers, does not include the samples axis) when using this layer as the 847 first layer in a model. 848 849 Output shape: 850 Specified by `output_shape` argument 851 """ 852 853 @trackable.no_automatic_dependency_tracking 854 def __init__(self, function, output_shape=None, mask=None, arguments=None, 855 **kwargs): 856 super(Lambda, self).__init__(**kwargs) 857 858 self.arguments = arguments or {} 859 self.function = function 860 861 if mask is not None: 862 self.supports_masking = True 863 self.mask = mask 864 self._output_shape = output_shape 865 866 # Warning on every invocation will be quite irksome in Eager mode. 867 self._already_warned = False 868 869 function_args = tf_inspect.getfullargspec(function).args 870 self._fn_expects_training_arg = 'training' in function_args 871 self._fn_expects_mask_arg = 'mask' in function_args 872 873 @tf_utils.shape_type_conversion 874 def compute_output_shape(self, input_shape): 875 if self._output_shape is None: 876 # Make use of existing autocomputation but provide Lambda-specific 877 # error message. This is always safe to run even when the outer context 878 # is Graph mode because Lambda layers don't have side effects such as 879 # `add_loss`. 880 with context.eager_mode(): 881 try: 882 return super(Lambda, self).compute_output_shape(input_shape) 883 except NotImplementedError: 884 raise NotImplementedError( 885 'We could not automatically infer the shape of the Lambda\'s ' 886 'output. Please specify `output_shape` for this Lambda.') 887 888 if callable(self._output_shape): 889 output_shapes = self._output_shape(input_shape) 890 return tf_utils.convert_shapes(output_shapes, to_tuples=False) 891 892 # Output shapes are passed directly and don't include batch dimension. 893 input_tensor_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) 894 batch_size = nest.flatten(input_tensor_shape)[0][0] if input_shape else None 895 896 def _add_batch(shape): 897 return tensor_shape.TensorShape([batch_size] + shape.as_list()) 898 899 output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False) 900 return nest.map_structure(_add_batch, output_shapes) 901 902 def call(self, inputs, mask=None, training=None): 903 # We must copy for thread safety, but it only needs to be a shallow copy. 904 kwargs = {k: v for k, v in self.arguments.items()} 905 if self._fn_expects_mask_arg: 906 kwargs['mask'] = mask 907 if self._fn_expects_training_arg: 908 kwargs['training'] = training 909 910 created_variables = [] 911 def _variable_creator(next_creator, **kwargs): 912 var = next_creator(**kwargs) 913 created_variables.append(var) 914 return var 915 916 with backprop.GradientTape(watch_accessed_variables=True) as tape,\ 917 variable_scope.variable_creator_scope(_variable_creator): 918 result = self.function(inputs, **kwargs) 919 self._check_variables(created_variables, tape.watched_variables()) 920 return result 921 922 def _check_variables(self, created_variables, accessed_variables): 923 if not created_variables and not accessed_variables: 924 # In the common case that a Lambda layer does not touch a Variable, we 925 # don't want to incur the runtime cost of assembling any state used for 926 # checking only to immediately discard it. 927 return 928 929 tracked_weights = set(v.ref() for v in self.weights) 930 untracked_new_vars = [ 931 v for v in created_variables if v.ref() not in tracked_weights 932 ] 933 if untracked_new_vars: 934 variable_str = '\n'.join(' {}'.format(i) for i in untracked_new_vars) 935 error_str = textwrap.dedent( 936 ''' 937 The following Variables were created within a Lambda layer ({name}) 938 but are not tracked by said layer: 939 {variable_str} 940 The layer cannot safely ensure proper Variable reuse across multiple 941 calls, and consquently this behavior is disallowed for safety. Lambda 942 layers are not well suited to stateful computation; instead, writing a 943 subclassed Layer is the recommend way to define layers with 944 Variables.''' 945 ).format(name=self.name, variable_str=variable_str) 946 raise ValueError(error_str) 947 948 untracked_used_vars = [ 949 v for v in accessed_variables if v.ref() not in tracked_weights 950 ] 951 if untracked_used_vars and not self._already_warned: 952 variable_str = '\n'.join(' {}'.format(i) for i in untracked_used_vars) 953 self._warn(textwrap.dedent( 954 ''' 955 The following Variables were used a Lambda layer's call ({name}), but 956 are not present in its tracked objects: 957 {variable_str} 958 It is possible that this is intended behavior, but it is more likely 959 an omission. This is a strong indication that this layer should be 960 formulated as a subclassed Layer rather than a Lambda layer.''' 961 ).format(name=self.name, variable_str=variable_str)) 962 self._already_warned = True 963 964 def _warn(self, msg): 965 # This method will be overridden in a unit test to raise an error, because 966 # self.assertWarns is not universally implemented. 967 return tf_logging.warn(msg) 968 969 def compute_mask(self, inputs, mask=None): 970 if callable(self.mask): 971 return self.mask(inputs, mask) 972 return self.mask 973 974 def get_config(self): 975 function_config = self._serialize_function_to_config(self.function) 976 output_shape_config = self._serialize_function_to_config(self._output_shape, 977 allow_raw=True) 978 config = { 979 'function': function_config[0], 980 'function_type': function_config[1], 981 'module': function_config[2], 982 'output_shape': output_shape_config[0], 983 'output_shape_type': output_shape_config[1], 984 'output_shape_module': output_shape_config[2], 985 } 986 if self.mask is not None: 987 mask_config = self._serialize_function_to_config(self.mask) 988 config.update({ 989 'mask': mask_config[0], 990 'mask_type': mask_config[1], 991 'mask_module': mask_config[2] 992 }) 993 config['arguments'] = self.arguments 994 995 base_config = super(Lambda, self).get_config() 996 return dict(list(base_config.items()) + list(config.items())) 997 998 def _serialize_function_to_config(self, inputs, allow_raw=False): 999 if isinstance(inputs, python_types.LambdaType): 1000 output = generic_utils.func_dump(inputs) 1001 output_type = 'lambda' 1002 module = inputs.__module__ 1003 elif callable(inputs): 1004 output = inputs.__name__ 1005 output_type = 'function' 1006 module = inputs.__module__ 1007 elif allow_raw: 1008 output = inputs 1009 output_type = 'raw' 1010 module = None 1011 else: 1012 raise ValueError( 1013 'Invalid input for serialization, type: %s ' % type(inputs)) 1014 1015 return output, output_type, module 1016 1017 @classmethod 1018 def from_config(cls, config, custom_objects=None): 1019 config = config.copy() 1020 function = cls._parse_function_from_config( 1021 config, custom_objects, 'function', 'module', 'function_type') 1022 1023 output_shape = cls._parse_function_from_config( 1024 config, custom_objects, 'output_shape', 'output_shape_module', 1025 'output_shape_type') 1026 if 'mask' in config: 1027 mask = cls._parse_function_from_config( 1028 config, custom_objects, 'mask', 'mask_module', 'mask_type') 1029 else: 1030 mask = None 1031 1032 config['function'] = function 1033 config['output_shape'] = output_shape 1034 config['mask'] = mask 1035 1036 # If arguments were numpy array, they have been saved as 1037 # list. We need to recover the ndarray 1038 if 'arguments' in config: 1039 for key in config['arguments']: 1040 if isinstance(config['arguments'][key], dict): 1041 arg_dict = config['arguments'][key] 1042 if 'type' in arg_dict and arg_dict['type'] == 'ndarray': 1043 # Overwrite the argument with its numpy translation 1044 config['arguments'][key] = np.array(arg_dict['value']) 1045 1046 return cls(**config) 1047 1048 @classmethod 1049 def _parse_function_from_config( 1050 cls, config, custom_objects, func_attr_name, module_attr_name, 1051 func_type_attr_name): 1052 globs = globals().copy() 1053 module = config.pop(module_attr_name, None) 1054 if module in sys.modules: 1055 globs.update(sys.modules[module].__dict__) 1056 elif module is not None: 1057 # Note: we don't know the name of the function if it's a lambda. 1058 warnings.warn('{} is not loaded, but a Lambda layer uses it. ' 1059 'It may cause errors.'.format(module) 1060 , UserWarning) 1061 if custom_objects: 1062 globs.update(custom_objects) 1063 function_type = config.pop(func_type_attr_name) 1064 if function_type == 'function': 1065 # Simple lookup in custom objects 1066 function = generic_utils.deserialize_keras_object( 1067 config[func_attr_name], 1068 custom_objects=custom_objects, 1069 printable_module_name='function in Lambda layer') 1070 elif function_type == 'lambda': 1071 # Unsafe deserialization from bytecode 1072 function = generic_utils.func_load( 1073 config[func_attr_name], globs=globs) 1074 elif function_type == 'raw': 1075 function = config[func_attr_name] 1076 else: 1077 raise TypeError('Unknown function type:', function_type) 1078 return function 1079 1080 1081@keras_export('keras.layers.Dense') 1082class Dense(Layer): 1083 """Just your regular densely-connected NN layer. 1084 1085 `Dense` implements the operation: 1086 `output = activation(dot(input, kernel) + bias)` 1087 where `activation` is the element-wise activation function 1088 passed as the `activation` argument, `kernel` is a weights matrix 1089 created by the layer, and `bias` is a bias vector created by the layer 1090 (only applicable if `use_bias` is `True`). 1091 1092 Note: If the input to the layer has a rank greater than 2, then `Dense` 1093 computes the dot product between the `inputs` and the `kernel` along the 1094 last axis of the `inputs` and axis 1 of the `kernel` (using `tf.tensordot`). 1095 For example, if input has dimensions `(batch_size, d0, d1)`, 1096 then we create a `kernel` with shape `(d1, units)`, and the `kernel` operates 1097 along axis 2 of the `input`, on every sub-tensor of shape `(1, 1, d1)` 1098 (there are `batch_size * d0` such sub-tensors). 1099 The output in this case will have shape `(batch_size, d0, units)`. 1100 1101 Besides, layer attributes cannot be modified after the layer has been called 1102 once (except the `trainable` attribute). 1103 1104 Example: 1105 1106 >>> # Create a `Sequential` model and add a Dense layer as the first layer. 1107 >>> model = tf.keras.models.Sequential() 1108 >>> model.add(tf.keras.Input(shape=(16,))) 1109 >>> model.add(tf.keras.layers.Dense(32, activation='relu')) 1110 >>> # Now the model will take as input arrays of shape (None, 16) 1111 >>> # and output arrays of shape (None, 32). 1112 >>> # Note that after the first layer, you don't need to specify 1113 >>> # the size of the input anymore: 1114 >>> model.add(tf.keras.layers.Dense(32)) 1115 >>> model.output_shape 1116 (None, 32) 1117 1118 Args: 1119 units: Positive integer, dimensionality of the output space. 1120 activation: Activation function to use. 1121 If you don't specify anything, no activation is applied 1122 (ie. "linear" activation: `a(x) = x`). 1123 use_bias: Boolean, whether the layer uses a bias vector. 1124 kernel_initializer: Initializer for the `kernel` weights matrix. 1125 bias_initializer: Initializer for the bias vector. 1126 kernel_regularizer: Regularizer function applied to 1127 the `kernel` weights matrix. 1128 bias_regularizer: Regularizer function applied to the bias vector. 1129 activity_regularizer: Regularizer function applied to 1130 the output of the layer (its "activation"). 1131 kernel_constraint: Constraint function applied to 1132 the `kernel` weights matrix. 1133 bias_constraint: Constraint function applied to the bias vector. 1134 1135 Input shape: 1136 N-D tensor with shape: `(batch_size, ..., input_dim)`. 1137 The most common situation would be 1138 a 2D input with shape `(batch_size, input_dim)`. 1139 1140 Output shape: 1141 N-D tensor with shape: `(batch_size, ..., units)`. 1142 For instance, for a 2D input with shape `(batch_size, input_dim)`, 1143 the output would have shape `(batch_size, units)`. 1144 """ 1145 1146 def __init__(self, 1147 units, 1148 activation=None, 1149 use_bias=True, 1150 kernel_initializer='glorot_uniform', 1151 bias_initializer='zeros', 1152 kernel_regularizer=None, 1153 bias_regularizer=None, 1154 activity_regularizer=None, 1155 kernel_constraint=None, 1156 bias_constraint=None, 1157 **kwargs): 1158 super(Dense, self).__init__( 1159 activity_regularizer=activity_regularizer, **kwargs) 1160 1161 self.units = int(units) if not isinstance(units, int) else units 1162 self.activation = activations.get(activation) 1163 self.use_bias = use_bias 1164 self.kernel_initializer = initializers.get(kernel_initializer) 1165 self.bias_initializer = initializers.get(bias_initializer) 1166 self.kernel_regularizer = regularizers.get(kernel_regularizer) 1167 self.bias_regularizer = regularizers.get(bias_regularizer) 1168 self.kernel_constraint = constraints.get(kernel_constraint) 1169 self.bias_constraint = constraints.get(bias_constraint) 1170 1171 self.input_spec = InputSpec(min_ndim=2) 1172 self.supports_masking = True 1173 1174 def build(self, input_shape): 1175 dtype = dtypes.as_dtype(self.dtype or K.floatx()) 1176 if not (dtype.is_floating or dtype.is_complex): 1177 raise TypeError('Unable to build `Dense` layer with non-floating point ' 1178 'dtype %s' % (dtype,)) 1179 1180 input_shape = tensor_shape.TensorShape(input_shape) 1181 last_dim = tensor_shape.dimension_value(input_shape[-1]) 1182 if last_dim is None: 1183 raise ValueError('The last dimension of the inputs to `Dense` ' 1184 'should be defined. Found `None`.') 1185 self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim}) 1186 self.kernel = self.add_weight( 1187 'kernel', 1188 shape=[last_dim, self.units], 1189 initializer=self.kernel_initializer, 1190 regularizer=self.kernel_regularizer, 1191 constraint=self.kernel_constraint, 1192 dtype=self.dtype, 1193 trainable=True) 1194 if self.use_bias: 1195 self.bias = self.add_weight( 1196 'bias', 1197 shape=[self.units,], 1198 initializer=self.bias_initializer, 1199 regularizer=self.bias_regularizer, 1200 constraint=self.bias_constraint, 1201 dtype=self.dtype, 1202 trainable=True) 1203 else: 1204 self.bias = None 1205 self.built = True 1206 1207 def call(self, inputs): 1208 return core_ops.dense( 1209 inputs, 1210 self.kernel, 1211 self.bias, 1212 self.activation, 1213 dtype=self._compute_dtype_object) 1214 1215 def compute_output_shape(self, input_shape): 1216 input_shape = tensor_shape.TensorShape(input_shape) 1217 input_shape = input_shape.with_rank_at_least(2) 1218 if tensor_shape.dimension_value(input_shape[-1]) is None: 1219 raise ValueError( 1220 'The innermost dimension of input_shape must be defined, but saw: %s' 1221 % input_shape) 1222 return input_shape[:-1].concatenate(self.units) 1223 1224 def get_config(self): 1225 config = super(Dense, self).get_config() 1226 config.update({ 1227 'units': 1228 self.units, 1229 'activation': 1230 activations.serialize(self.activation), 1231 'use_bias': 1232 self.use_bias, 1233 'kernel_initializer': 1234 initializers.serialize(self.kernel_initializer), 1235 'bias_initializer': 1236 initializers.serialize(self.bias_initializer), 1237 'kernel_regularizer': 1238 regularizers.serialize(self.kernel_regularizer), 1239 'bias_regularizer': 1240 regularizers.serialize(self.bias_regularizer), 1241 'activity_regularizer': 1242 regularizers.serialize(self.activity_regularizer), 1243 'kernel_constraint': 1244 constraints.serialize(self.kernel_constraint), 1245 'bias_constraint': 1246 constraints.serialize(self.bias_constraint) 1247 }) 1248 return config 1249 1250 1251@keras_export('keras.layers.ActivityRegularization') 1252class ActivityRegularization(Layer): 1253 """Layer that applies an update to the cost function based input activity. 1254 1255 Args: 1256 l1: L1 regularization factor (positive float). 1257 l2: L2 regularization factor (positive float). 1258 1259 Input shape: 1260 Arbitrary. Use the keyword argument `input_shape` 1261 (tuple of integers, does not include the samples axis) 1262 when using this layer as the first layer in a model. 1263 1264 Output shape: 1265 Same shape as input. 1266 """ 1267 1268 def __init__(self, l1=0., l2=0., **kwargs): 1269 super(ActivityRegularization, self).__init__( 1270 activity_regularizer=regularizers.L1L2(l1=l1, l2=l2), **kwargs) 1271 self.supports_masking = True 1272 self.l1 = l1 1273 self.l2 = l2 1274 1275 def compute_output_shape(self, input_shape): 1276 return input_shape 1277 1278 def get_config(self): 1279 config = {'l1': self.l1, 'l2': self.l2} 1280 base_config = super(ActivityRegularization, self).get_config() 1281 return dict(list(base_config.items()) + list(config.items())) 1282 1283 1284class TFOpLambda(Layer): 1285 """Wraps TF API symbols in a `Layer` object. 1286 1287 It is inserted by the Functional API construction whenever users call 1288 a supported TF symbol on KerasTensors. 1289 1290 Like Lambda layers, this layer tries to raise warnings when it detects users 1291 explicitly use variables in the call. (To let them know 1292 that the layer will not capture the variables). 1293 1294 This is useful in the case where users do something like: 1295 x = keras.Input(...) 1296 y = tf.Variable(...) 1297 out = x * tf_variable 1298 """ 1299 1300 @trackable.no_automatic_dependency_tracking 1301 def __init__(self, function, **kwargs): 1302 self.function = function 1303 self.symbol = ( 1304 get_canonical_name_for_symbol(self.function, 1305 add_prefix_to_v1_names=True) or 1306 get_canonical_name_for_symbol(self.function, 1307 api_name='keras', 1308 add_prefix_to_v1_names=True)) 1309 if 'name' not in kwargs: 1310 # Generate a name. 1311 # TFOpLambda layers avoid already-observed names, 1312 # because users cannot easily control the generated names. 1313 # Without this avoidance, users would be more likely to run 1314 # into unavoidable duplicate layer name collisions. 1315 # (For standard layers users could just set `name` when creating the 1316 # layer to work around a collision, but they can't do that for 1317 # auto-generated layers) 1318 if self.symbol: 1319 name = 'tf.' + self.symbol 1320 else: 1321 name = self.function.__name__ 1322 kwargs['name'] = K.unique_object_name( 1323 name, zero_based=True, avoid_observed_names=True) 1324 kwargs['autocast'] = False 1325 1326 # Decorate the function to produce this layer's call method 1327 def _call_wrapper(*args, **kwargs): 1328 return self._call_wrapper(*args, **kwargs) 1329 self.call = tf_decorator.make_decorator(function, _call_wrapper) 1330 1331 # Do not individually trace op layers in the SavedModel. 1332 self._must_restore_from_config = True 1333 1334 super(TFOpLambda, self).__init__(**kwargs) 1335 1336 # Preserve all argument data structures when saving/loading a config 1337 # (e.g., don't unnest lists that contain one element) 1338 self._preserve_input_structure_in_config = True 1339 1340 # Warning on every invocation will be quite irksome in Eager mode. 1341 self._already_warned = False 1342 1343 self._expects_training_arg = False 1344 self._expects_mask_arg = False 1345 1346 def _call_wrapper(self, *args, **kwargs): 1347 created_variables = [] 1348 def _variable_creator(next_creator, **creator_kwargs): 1349 var = next_creator(**creator_kwargs) 1350 created_variables.append(var) 1351 return var 1352 1353 with backprop.GradientTape(watch_accessed_variables=True) as tape, \ 1354 variable_scope.variable_creator_scope(_variable_creator): 1355 # We explicitly drop `name` arguments here, 1356 # to guard against the case where an op explicitly has a 1357 # `name` passed (which is susceptible to producing 1358 # multiple ops w/ the same name when the layer is reused) 1359 kwargs.pop('name', None) 1360 result = self.function(*args, **kwargs) 1361 self._check_variables(created_variables, tape.watched_variables()) 1362 return result 1363 1364 def _check_variables(self, created_variables, accessed_variables): 1365 if not created_variables and not accessed_variables: 1366 # In the common case that a Lambda layer does not touch a Variable, we 1367 # don't want to incur the runtime cost of assembling any state used for 1368 # checking only to immediately discard it. 1369 return 1370 1371 tracked_weights = set(v.ref() for v in self.weights) 1372 untracked_new_vars = [ 1373 v for v in created_variables if v.ref() not in tracked_weights 1374 ] 1375 if untracked_new_vars: 1376 variable_str = '\n'.join(' {}'.format(i) for i in untracked_new_vars) 1377 error_str = textwrap.dedent( 1378 ''' 1379 The following Variables were created within a Lambda layer ({name}) 1380 but are not tracked by said layer: 1381 {variable_str} 1382 The layer cannot safely ensure proper Variable reuse across multiple 1383 calls, and consquently this behavior is disallowed for safety. Lambda 1384 layers are not well suited to stateful computation; instead, writing a 1385 subclassed Layer is the recommend way to define layers with 1386 Variables.''' 1387 ).format(name=self.name, variable_str=variable_str) 1388 raise ValueError(error_str) 1389 1390 untracked_used_vars = [ 1391 v for v in accessed_variables if v.ref() not in tracked_weights 1392 ] 1393 if untracked_used_vars and not self._already_warned: 1394 variable_str = '\n'.join(' {}'.format(i) for i in untracked_used_vars) 1395 self._warn(textwrap.dedent( 1396 ''' 1397 The following Variables were used a Lambda layer's call ({name}), but 1398 are not present in its tracked objects: 1399 {variable_str} 1400 It is possible that this is intended behavior, but it is more likely 1401 an omission. This is a strong indication that this layer should be 1402 formulated as a subclassed Layer rather than a Lambda layer.''' 1403 ).format(name=self.name, variable_str=variable_str)) 1404 self._already_warned = True 1405 1406 def _warn(self, msg): 1407 # This method will be overridden in a unit test to raise an error, because 1408 # self.assertWarns is not universally implemented. 1409 return tf_logging.warn(msg) 1410 1411 def get_config(self): 1412 if not self.symbol: 1413 raise ValueError('This Keras op layer was generated from %s, a method ' 1414 'that is not an exposed in the TensorFlow API. This ' 1415 'may have happened if the method was explicitly ' 1416 'decorated to add dispatching support, and it was used ' 1417 'during Functional model construction. ' 1418 'To ensure cross-version compatibility of Keras models ' 1419 'that use op layers, only op layers produced from ' 1420 'exported TF API symbols can be serialized.' 1421 % self.function) 1422 config = { 1423 'function': self.symbol 1424 } 1425 1426 base_config = super(TFOpLambda, self).get_config() 1427 return dict(list(base_config.items()) + list(config.items())) 1428 1429 @classmethod 1430 def from_config(cls, config, custom_objects=None): 1431 config = config.copy() 1432 symbol_name = config['function'] 1433 function = get_symbol_from_name(symbol_name) 1434 if not function: 1435 raise ValueError( 1436 'TF symbol `tf.%s` could not be found.' % symbol_name) 1437 1438 config['function'] = function 1439 1440 return cls(**config) 1441 1442 1443class KerasOpDispatcher(dispatch.GlobalOpDispatcher): 1444 """A global dispatcher that allows building a functional model with TF Ops.""" 1445 1446 def handle(self, op, args, kwargs): 1447 """Handle the specified operation with the specified arguments.""" 1448 if any( 1449 isinstance(x, keras_tensor.KerasTensor) 1450 for x in nest.flatten([args, kwargs])): 1451 return TFOpLambda(op)(*args, **kwargs) 1452 else: 1453 return self.NOT_SUPPORTED 1454 1455KerasOpDispatcher().register() 1456 1457 1458def _slice_to_dict(x): 1459 if isinstance(x, slice): 1460 return {'start': x.start, 'stop': x.stop, 'step': x.step} 1461 return x 1462 1463 1464def _dict_to_slice(x): 1465 if isinstance(x, dict): 1466 return slice(x['start'], x['stop'], x['step']) 1467 return x 1468 1469 1470class SlicingOpLambda(TFOpLambda): 1471 """Wraps TF API symbols in a `Layer` object. 1472 1473 It is inserted by the Functional API construction whenever users call 1474 a supported TF symbol on KerasTensors. 1475 1476 Like Lambda layers, this layer tries to raise warnings when it detects users 1477 explicitly use variables in the call. (To let them know 1478 that the layer will not capture the variables). 1479 1480 This is useful in the case where users do something like: 1481 x = keras.Input(...) 1482 y = tf.Variable(...) 1483 out = x * tf_variable 1484 """ 1485 1486 @trackable.no_automatic_dependency_tracking 1487 def __init__(self, function, **kwargs): 1488 super(SlicingOpLambda, self).__init__(function, **kwargs) 1489 1490 original_call = self.call 1491 # Decorate the function to produce this layer's call method 1492 def _call_wrapper(*args, **kwargs): 1493 # Turn any slice dicts in the args back into `slice` objects. 1494 # This conversion cannot use nest.flatten/map_structure, 1495 # because dicts are flattened by nest while slices aren't. 1496 # So, map_structure would only see the individual elements in the 1497 # dict. 1498 # This can't use map_structure_up_to either because the 'shallowness' of 1499 # the shallow tree would have to vary depending on if only one dim or 1500 # multiple are being sliced. 1501 new_args = [] 1502 for arg in args: 1503 arg = _dict_to_slice(arg) 1504 if isinstance(arg, (list, tuple)): 1505 new_arg = [] 1506 for sub_arg in arg: 1507 new_arg.append(_dict_to_slice(sub_arg)) 1508 arg = new_arg 1509 new_args.append(arg) 1510 1511 # Handle the kwargs too. 1512 new_kwargs = {} 1513 for key, value in kwargs.items(): 1514 value = _dict_to_slice(value) 1515 if isinstance(value, (list, tuple)): 1516 new_value = [] 1517 for v in value: 1518 new_value.append(_dict_to_slice(v)) 1519 value = new_value 1520 new_kwargs[key] = value 1521 1522 return original_call(*new_args, **new_kwargs) 1523 self.call = tf_decorator.make_decorator(original_call, _call_wrapper) 1524 1525 1526class TFSlicingOpDispatcher(dispatch.OpDispatcher): 1527 """A global dispatcher that allows building a functional model with TF Ops.""" 1528 1529 def __init__(self, op): 1530 self.op = op 1531 1532 def handle(self, args, kwargs): 1533 """Handle the specified operation with the specified arguments.""" 1534 args = nest.map_structure(_slice_to_dict, args) 1535 kwargs = nest.map_structure(_slice_to_dict, kwargs) 1536 if any( 1537 isinstance(x, keras_tensor.KerasTensor) 1538 for x in nest.flatten([args, kwargs])): 1539 return SlicingOpLambda(self.op)(*args, **kwargs) 1540 else: 1541 return self.NOT_SUPPORTED 1542 1543for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access 1544 array_ops.boolean_mask, 1545 array_ops.boolean_mask_v2]: 1546 TFSlicingOpDispatcher(slicing_op).register(slicing_op) 1547 1548 1549class InstanceProperty(Layer): 1550 """Wraps an instance property access (e.g. `x.foo`) in a Keras Layer. 1551 1552 This layer takes an attribute name `attr_name` in the constructor and, 1553 when called on input tensor `obj` returns `obj.attr_name`. 1554 1555 KerasTensors specialized for specific extension types use it to 1556 represent instance property accesses on the represented object in the 1557 case where the property needs to be dynamically accessed as opposed to 1558 being statically computed from the typespec, e.g. 1559 1560 x = keras.Input(..., ragged=True) 1561 out = x.flat_values 1562 """ 1563 1564 @trackable.no_automatic_dependency_tracking 1565 def __init__(self, attr_name, **kwargs): 1566 self.attr_name = attr_name 1567 1568 if 'name' not in kwargs: 1569 kwargs['name'] = K.unique_object_name( 1570 'input.' + self.attr_name, zero_based=True, avoid_observed_names=True) 1571 kwargs['autocast'] = False 1572 1573 # Do not individually trace op layers in the SavedModel. 1574 self._must_restore_from_config = True 1575 1576 super(InstanceProperty, self).__init__(**kwargs) 1577 1578 # Preserve all argument data structures when saving/loading a config 1579 # (e.g., don't unnest lists that contain one element) 1580 self._preserve_input_structure_in_config = True 1581 1582 def call(self, obj): 1583 return getattr(obj, self.attr_name) 1584 1585 def get_config(self): 1586 config = { 1587 'attr_name': self.attr_name 1588 } 1589 base_config = super(InstanceProperty, self).get_config() 1590 return dict(list(base_config.items()) + list(config.items())) 1591 1592 @classmethod 1593 def from_config(cls, config, custom_objects=None): 1594 return cls(**config) 1595 1596 1597class InstanceMethod(InstanceProperty): 1598 """Wraps an instance method access (e.g. `x.foo(arg)` in a Keras Layer. 1599 1600 This layer takes an attribute name `attr_name` in the constructor and, 1601 when called on input tensor `obj` with additional arguments `args` and 1602 `kwargs` returns `obj.attr_name(*args, **kwargs)`. 1603 1604 KerasTensors specialized for specific extension types use it to 1605 represent dynamic instance method calls on the represented object, e.g. 1606 1607 x = keras.Input(..., ragged=True) 1608 new_values = keras.Input(...) 1609 out = x.with_values(new_values) 1610 """ 1611 1612 def call(self, obj, args, kwargs): 1613 method = getattr(obj, self.attr_name) 1614 return method(*args, **kwargs) 1615 1616 1617def _delegate_property(keras_tensor_cls, property_name): # pylint: disable=invalid-name 1618 """Register property on a KerasTensor class. 1619 1620 Calling this multiple times with the same arguments should be a no-op. 1621 1622 This method exposes a property on the KerasTensor class that will use an 1623 `InstanceProperty` layer to access the property on the represented 1624 intermediate values in the model. 1625 1626 Args: 1627 keras_tensor_cls: The KerasTensor subclass that should expose the property. 1628 property_name: The name of the property to expose and delegate to the 1629 represented (Composite)Tensor. 1630 """ 1631 # We use a lambda because we can't create a Keras layer at import time 1632 # due to dynamic layer class versioning. 1633 property_access = property(lambda self: InstanceProperty(property_name)(self)) # pylint: disable=unnecessary-lambda 1634 setattr(keras_tensor_cls, property_name, property_access) 1635 1636 1637def _delegate_method(keras_tensor_cls, method_name): # pylint: disable=invalid-name 1638 """Register method on a KerasTensor class. 1639 1640 Calling this function times with the same arguments should be a no-op. 1641 1642 This method exposes an instance method on the KerasTensor class that will use 1643 an `InstanceMethod` layer to run the desired method on the represented 1644 intermediate values in the model. 1645 1646 Args: 1647 keras_tensor_cls: The KerasTensor subclass that should expose the property. 1648 method_name: The name of the method to expose and delegate to the 1649 represented (Composite)Tensor. 1650 """ 1651 def delegate(self, *args, **kwargs): 1652 return InstanceMethod(method_name)(self, args, kwargs) 1653 setattr(keras_tensor_cls, method_name, delegate) 1654 1655# We do not support the `uniform_row_length` property because it 1656# returns either `None` or an int tensor, and code that relies on it tends 1657# to check `is None` directly. Delegating it here would always return a 1658# `KerasTensor`, regardless of what can be statically inferred. This would 1659# never equal `None`, breaking code that expects it to be partially-static 1660# in unpredictable ways. 1661for ragged_property in [ 1662 'values', 1663 'flat_values', 1664 'row_splits', 1665 'nested_row_splits' 1666]: 1667 _delegate_property(keras_tensor.RaggedKerasTensor, ragged_property) 1668 1669for ragged_method_name in [ 1670 'value_rowids', 1671 'nested_value_rowids', 1672 'nrows', 1673 'row_starts', 1674 'row_limits', 1675 'row_lengths', 1676 'nested_row_lengths', 1677 'bounding_shape', 1678 'with_values', 1679 'with_flat_values', 1680 'with_row_splits_dtype', 1681 'merge_dims', 1682 'to_tensor', 1683 'to_sparse', 1684]: 1685 _delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name) 1686 1687for sparse_property in [ 1688 'indices', 1689 'values', 1690]: 1691 _delegate_property(keras_tensor.SparseKerasTensor, sparse_property) 1692 1693for sparse_method in [ 1694 'with_values', 1695]: 1696 _delegate_method(keras_tensor.SparseKerasTensor, sparse_method) 1697 1698 1699class ClassMethod(Layer): 1700 """Wraps a TF API Class's class method in a `Layer` object. 1701 1702 It is inserted by the Functional API construction whenever users call 1703 a supported TF Class's class method on KerasTensors. 1704 1705 This is useful in the case where users do something like: 1706 x = keras.Input(...) 1707 y = keras.Input(...) 1708 out = tf.RaggedTensor.from_row_splits(x, y) 1709 """ 1710 1711 @trackable.no_automatic_dependency_tracking 1712 def __init__(self, cls_ref, method_name, **kwargs): 1713 self.cls_ref = cls_ref 1714 self.method_name = method_name 1715 self.cls_symbol = ( 1716 get_canonical_name_for_symbol(self.cls_ref, 1717 add_prefix_to_v1_names=True) or 1718 get_canonical_name_for_symbol(self.cls_ref, 1719 api_name='keras', 1720 add_prefix_to_v1_names=True)) 1721 if 'name' not in kwargs: 1722 kwargs['name'] = K.unique_object_name( 1723 'tf.' + self.cls_symbol + '.' + self.method_name, zero_based=True, 1724 avoid_observed_names=True) 1725 kwargs['autocast'] = False 1726 1727 # Do not individually trace op layers in the SavedModel. 1728 self._must_restore_from_config = True 1729 1730 super(ClassMethod, self).__init__(**kwargs) 1731 1732 # Preserve all argument data structures when saving/loading a config 1733 # (e.g., don't unnest lists that contain one element) 1734 self._preserve_input_structure_in_config = True 1735 1736 self._expects_training_arg = False 1737 self._expects_mask_arg = False 1738 1739 def call(self, args, kwargs): 1740 return getattr(self.cls_ref, self.method_name)(*args, **kwargs) 1741 1742 def get_config(self): 1743 if not self.cls_symbol: 1744 raise ValueError('This Keras class method conversion tried to convert ' 1745 'a method belonging to class %s, a class ' 1746 'that is not an exposed in the TensorFlow API. ' 1747 'To ensure cross-version compatibility of Keras models ' 1748 'that use op layers, only op layers produced from ' 1749 'exported TF API symbols can be serialized.' 1750 % self.cls_symbol) 1751 config = { 1752 'cls_symbol': self.cls_symbol, 1753 'method_name': self.method_name 1754 } 1755 1756 base_config = super(ClassMethod, self).get_config() 1757 return dict(list(base_config.items()) + list(config.items())) 1758 1759 @classmethod 1760 def from_config(cls, config, custom_objects=None): 1761 config = config.copy() 1762 symbol_name = config.pop('cls_symbol') 1763 cls_ref = get_symbol_from_name(symbol_name) 1764 if not cls_ref: 1765 raise ValueError( 1766 'TF symbol `tf.%s` could not be found.' % symbol_name) 1767 1768 config['cls_ref'] = cls_ref 1769 1770 return cls(**config) 1771 1772 1773class TFClassMethodDispatcher(dispatch.OpDispatcher): 1774 """A class method dispatcher that allows building a functional model with TF class methods.""" 1775 1776 def __init__(self, cls, method_name): 1777 self.cls = cls 1778 self.method_name = method_name 1779 1780 def handle(self, args, kwargs): 1781 """Handle the specified operation with the specified arguments.""" 1782 if any( 1783 isinstance(x, keras_tensor.KerasTensor) 1784 for x in nest.flatten([args, kwargs])): 1785 return ClassMethod(self.cls, self.method_name)(args[1:], kwargs) 1786 else: 1787 return self.NOT_SUPPORTED 1788 1789for ragged_class_method in [ 1790 'from_value_rowids', 1791 'from_row_splits', 1792 'from_row_lengths', 1793 'from_row_starts', 1794 'from_row_limits', 1795 'from_uniform_row_length', 1796 'from_nested_value_rowids', 1797 'from_nested_row_splits', 1798 'from_nested_row_lengths', 1799 'from_tensor', 1800 'from_sparse', 1801]: 1802 TFClassMethodDispatcher( 1803 ragged_tensor.RaggedTensor, ragged_class_method).register( 1804 getattr(ragged_tensor.RaggedTensor, ragged_class_method)) 1805