1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras convolution layers and image transformation layers. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from tensorflow.python.eager import context 23from tensorflow.python.framework import tensor_shape 24from tensorflow.python.keras import activations 25from tensorflow.python.keras import backend 26from tensorflow.python.keras import constraints 27from tensorflow.python.keras import initializers 28from tensorflow.python.keras import regularizers 29from tensorflow.python.keras.engine.base_layer import Layer 30from tensorflow.python.keras.engine.input_spec import InputSpec 31# imports for backwards namespace compatibility 32# pylint: disable=unused-import 33from tensorflow.python.keras.layers.pooling import AveragePooling1D 34from tensorflow.python.keras.layers.pooling import AveragePooling2D 35from tensorflow.python.keras.layers.pooling import AveragePooling3D 36from tensorflow.python.keras.layers.pooling import MaxPooling1D 37from tensorflow.python.keras.layers.pooling import MaxPooling2D 38from tensorflow.python.keras.layers.pooling import MaxPooling3D 39# pylint: enable=unused-import 40from tensorflow.python.keras.utils import conv_utils 41from tensorflow.python.keras.utils import 
tf_utils 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import nn 44from tensorflow.python.ops import nn_ops 45from tensorflow.python.util.tf_export import keras_export 46 47 48class Conv(Layer): 49 """Abstract N-D convolution layer (private, used as implementation base). 50 51 This layer creates a convolution kernel that is convolved 52 (actually cross-correlated) with the layer input to produce a tensor of 53 outputs. If `use_bias` is True (and a `bias_initializer` is provided), 54 a bias vector is created and added to the outputs. Finally, if 55 `activation` is not `None`, it is applied to the outputs as well. 56 57 Arguments: 58 rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. 59 filters: Integer, the dimensionality of the output space (i.e. the number 60 of filters in the convolution). 61 kernel_size: An integer or tuple/list of n integers, specifying the 62 length of the convolution window. 63 strides: An integer or tuple/list of n integers, 64 specifying the stride length of the convolution. 65 Specifying any stride value != 1 is incompatible with specifying 66 any `dilation_rate` value != 1. 67 padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive). 68 data_format: A string, one of `channels_last` (default) or `channels_first`. 69 The ordering of the dimensions in the inputs. 70 `channels_last` corresponds to inputs with shape 71 `(batch, ..., channels)` while `channels_first` corresponds to 72 inputs with shape `(batch, channels, ...)`. 73 dilation_rate: An integer or tuple/list of n integers, specifying 74 the dilation rate to use for dilated convolution. 75 Currently, specifying any `dilation_rate` value != 1 is 76 incompatible with specifying any `strides` value != 1. 77 activation: Activation function. Set it to None to maintain a 78 linear activation. 79 use_bias: Boolean, whether the layer uses a bias. 80 kernel_initializer: An initializer for the convolution kernel. 
81 bias_initializer: An initializer for the bias vector. If None, the default 82 initializer will be used. 83 kernel_regularizer: Optional regularizer for the convolution kernel. 84 bias_regularizer: Optional regularizer for the bias vector. 85 activity_regularizer: Optional regularizer function for the output. 86 kernel_constraint: Optional projection function to be applied to the 87 kernel after being updated by an `Optimizer` (e.g. used to implement 88 norm constraints or value constraints for layer weights). The function 89 must take as input the unprojected variable and must return the 90 projected variable (which must have the same shape). Constraints are 91 not safe to use when doing asynchronous distributed training. 92 bias_constraint: Optional projection function to be applied to the 93 bias after being updated by an `Optimizer`. 94 trainable: Boolean, if `True` the weights of this layer will be marked as 95 trainable (and listed in `layer.trainable_weights`). 96 name: A string, the name of the layer. 
97 """ 98 99 def __init__(self, rank, 100 filters, 101 kernel_size, 102 strides=1, 103 padding='valid', 104 data_format=None, 105 dilation_rate=1, 106 activation=None, 107 use_bias=True, 108 kernel_initializer='glorot_uniform', 109 bias_initializer='zeros', 110 kernel_regularizer=None, 111 bias_regularizer=None, 112 activity_regularizer=None, 113 kernel_constraint=None, 114 bias_constraint=None, 115 trainable=True, 116 name=None, 117 **kwargs): 118 super(Conv, self).__init__( 119 trainable=trainable, 120 name=name, 121 activity_regularizer=regularizers.get(activity_regularizer), 122 **kwargs) 123 self.rank = rank 124 self.filters = filters 125 self.kernel_size = conv_utils.normalize_tuple( 126 kernel_size, rank, 'kernel_size') 127 self.strides = conv_utils.normalize_tuple(strides, rank, 'strides') 128 self.padding = conv_utils.normalize_padding(padding) 129 if (self.padding == 'causal' and not isinstance(self, 130 (Conv1D, SeparableConv1D))): 131 raise ValueError('Causal padding is only supported for `Conv1D`' 132 'and ``SeparableConv1D`.') 133 self.data_format = conv_utils.normalize_data_format(data_format) 134 self.dilation_rate = conv_utils.normalize_tuple( 135 dilation_rate, rank, 'dilation_rate') 136 self.activation = activations.get(activation) 137 self.use_bias = use_bias 138 self.kernel_initializer = initializers.get(kernel_initializer) 139 self.bias_initializer = initializers.get(bias_initializer) 140 self.kernel_regularizer = regularizers.get(kernel_regularizer) 141 self.bias_regularizer = regularizers.get(bias_regularizer) 142 self.kernel_constraint = constraints.get(kernel_constraint) 143 self.bias_constraint = constraints.get(bias_constraint) 144 self.input_spec = InputSpec(ndim=self.rank + 2) 145 146 def build(self, input_shape): 147 input_shape = tensor_shape.TensorShape(input_shape) 148 if self.data_format == 'channels_first': 149 channel_axis = 1 150 else: 151 channel_axis = -1 152 if input_shape.dims[channel_axis].value is None: 153 raise 
ValueError('The channel dimension of the inputs ' 154 'should be defined. Found `None`.') 155 input_dim = int(input_shape[channel_axis]) 156 kernel_shape = self.kernel_size + (input_dim, self.filters) 157 158 self.kernel = self.add_weight( 159 name='kernel', 160 shape=kernel_shape, 161 initializer=self.kernel_initializer, 162 regularizer=self.kernel_regularizer, 163 constraint=self.kernel_constraint, 164 trainable=True, 165 dtype=self.dtype) 166 if self.use_bias: 167 self.bias = self.add_weight( 168 name='bias', 169 shape=(self.filters,), 170 initializer=self.bias_initializer, 171 regularizer=self.bias_regularizer, 172 constraint=self.bias_constraint, 173 trainable=True, 174 dtype=self.dtype) 175 else: 176 self.bias = None 177 self.input_spec = InputSpec(ndim=self.rank + 2, 178 axes={channel_axis: input_dim}) 179 if self.padding == 'causal': 180 op_padding = 'valid' 181 else: 182 op_padding = self.padding 183 if not isinstance(op_padding, (list, tuple)): 184 op_padding = op_padding.upper() 185 self._convolution_op = nn_ops.Convolution( 186 input_shape, 187 filter_shape=self.kernel.get_shape(), 188 dilation_rate=self.dilation_rate, 189 strides=self.strides, 190 padding=op_padding, 191 data_format=conv_utils.convert_data_format(self.data_format, 192 self.rank + 2)) 193 self.built = True 194 195 def call(self, inputs): 196 outputs = self._convolution_op(inputs, self.kernel) 197 198 if self.use_bias: 199 if self.data_format == 'channels_first': 200 if self.rank == 1: 201 # nn.bias_add does not accept a 1D input tensor. 
202 bias = array_ops.reshape(self.bias, (1, self.filters, 1)) 203 outputs += bias 204 else: 205 outputs = nn.bias_add(outputs, self.bias, data_format='NCHW') 206 else: 207 outputs = nn.bias_add(outputs, self.bias, data_format='NHWC') 208 209 if self.activation is not None: 210 return self.activation(outputs) 211 return outputs 212 213 def compute_output_shape(self, input_shape): 214 input_shape = tensor_shape.TensorShape(input_shape).as_list() 215 if self.data_format == 'channels_last': 216 space = input_shape[1:-1] 217 new_space = [] 218 for i in range(len(space)): 219 new_dim = conv_utils.conv_output_length( 220 space[i], 221 self.kernel_size[i], 222 padding=self.padding, 223 stride=self.strides[i], 224 dilation=self.dilation_rate[i]) 225 new_space.append(new_dim) 226 return tensor_shape.TensorShape([input_shape[0]] + new_space + 227 [self.filters]) 228 else: 229 space = input_shape[2:] 230 new_space = [] 231 for i in range(len(space)): 232 new_dim = conv_utils.conv_output_length( 233 space[i], 234 self.kernel_size[i], 235 padding=self.padding, 236 stride=self.strides[i], 237 dilation=self.dilation_rate[i]) 238 new_space.append(new_dim) 239 return tensor_shape.TensorShape([input_shape[0], self.filters] + 240 new_space) 241 242 def get_config(self): 243 config = { 244 'filters': self.filters, 245 'kernel_size': self.kernel_size, 246 'strides': self.strides, 247 'padding': self.padding, 248 'data_format': self.data_format, 249 'dilation_rate': self.dilation_rate, 250 'activation': activations.serialize(self.activation), 251 'use_bias': self.use_bias, 252 'kernel_initializer': initializers.serialize(self.kernel_initializer), 253 'bias_initializer': initializers.serialize(self.bias_initializer), 254 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 255 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 256 'activity_regularizer': 257 regularizers.serialize(self.activity_regularizer), 258 'kernel_constraint': 
constraints.serialize(self.kernel_constraint), 259 'bias_constraint': constraints.serialize(self.bias_constraint) 260 } 261 base_config = super(Conv, self).get_config() 262 return dict(list(base_config.items()) + list(config.items())) 263 264 def _compute_causal_padding(self): 265 """Calculates padding for 'causal' option for 1-d conv layers.""" 266 left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1) 267 if self.data_format == 'channels_last': 268 causal_padding = [[0, 0], [left_pad, 0], [0, 0]] 269 else: 270 causal_padding = [[0, 0], [0, 0], [left_pad, 0]] 271 return causal_padding 272 273 274@keras_export('keras.layers.Conv1D', 'keras.layers.Convolution1D') 275class Conv1D(Conv): 276 """1D convolution layer (e.g. temporal convolution). 277 278 This layer creates a convolution kernel that is convolved 279 with the layer input over a single spatial (or temporal) dimension 280 to produce a tensor of outputs. 281 If `use_bias` is True, a bias vector is created and added to the outputs. 282 Finally, if `activation` is not `None`, 283 it is applied to the outputs as well. 284 285 When using this layer as the first layer in a model, 286 provide an `input_shape` argument 287 (tuple of integers or `None`, e.g. 288 `(10, 128)` for sequences of 10 vectors of 128-dimensional vectors, 289 or `(None, 128)` for variable-length sequences of 128-dimensional vectors. 290 291 Arguments: 292 filters: Integer, the dimensionality of the output space 293 (i.e. the number of output filters in the convolution). 294 kernel_size: An integer or tuple/list of a single integer, 295 specifying the length of the 1D convolution window. 296 strides: An integer or tuple/list of a single integer, 297 specifying the stride length of the convolution. 298 Specifying any stride value != 1 is incompatible with specifying 299 any `dilation_rate` value != 1. 300 padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive). 301 `"causal"` results in causal (dilated) convolutions, e.g. 
output[t] 302 does not depend on input[t+1:]. Useful when modeling temporal data 303 where the model should not violate the temporal order. 304 See [WaveNet: A Generative Model for Raw Audio, section 305 2.1](https://arxiv.org/abs/1609.03499). 306 data_format: A string, 307 one of `channels_last` (default) or `channels_first`. 308 dilation_rate: an integer or tuple/list of a single integer, specifying 309 the dilation rate to use for dilated convolution. 310 Currently, specifying any `dilation_rate` value != 1 is 311 incompatible with specifying any `strides` value != 1. 312 activation: Activation function to use. 313 If you don't specify anything, no activation is applied 314 (ie. "linear" activation: `a(x) = x`). 315 use_bias: Boolean, whether the layer uses a bias vector. 316 kernel_initializer: Initializer for the `kernel` weights matrix. 317 bias_initializer: Initializer for the bias vector. 318 kernel_regularizer: Regularizer function applied to 319 the `kernel` weights matrix. 320 bias_regularizer: Regularizer function applied to the bias vector. 321 activity_regularizer: Regularizer function applied to 322 the output of the layer (its "activation").. 323 kernel_constraint: Constraint function applied to the kernel matrix. 324 bias_constraint: Constraint function applied to the bias vector. 325 326 Input shape: 327 3D tensor with shape: `(batch_size, steps, input_dim)` 328 329 Output shape: 330 3D tensor with shape: `(batch_size, new_steps, filters)` 331 `steps` value might have changed due to padding or strides. 
332 """ 333 334 def __init__(self, 335 filters, 336 kernel_size, 337 strides=1, 338 padding='valid', 339 data_format='channels_last', 340 dilation_rate=1, 341 activation=None, 342 use_bias=True, 343 kernel_initializer='glorot_uniform', 344 bias_initializer='zeros', 345 kernel_regularizer=None, 346 bias_regularizer=None, 347 activity_regularizer=None, 348 kernel_constraint=None, 349 bias_constraint=None, 350 **kwargs): 351 super(Conv1D, self).__init__( 352 rank=1, 353 filters=filters, 354 kernel_size=kernel_size, 355 strides=strides, 356 padding=padding, 357 data_format=data_format, 358 dilation_rate=dilation_rate, 359 activation=activations.get(activation), 360 use_bias=use_bias, 361 kernel_initializer=initializers.get(kernel_initializer), 362 bias_initializer=initializers.get(bias_initializer), 363 kernel_regularizer=regularizers.get(kernel_regularizer), 364 bias_regularizer=regularizers.get(bias_regularizer), 365 activity_regularizer=regularizers.get(activity_regularizer), 366 kernel_constraint=constraints.get(kernel_constraint), 367 bias_constraint=constraints.get(bias_constraint), 368 **kwargs) 369 370 def call(self, inputs): 371 if self.padding == 'causal': 372 inputs = array_ops.pad(inputs, self._compute_causal_padding()) 373 return super(Conv1D, self).call(inputs) 374 375 376@keras_export('keras.layers.Conv2D', 'keras.layers.Convolution2D') 377class Conv2D(Conv): 378 """2D convolution layer (e.g. spatial convolution over images). 379 380 This layer creates a convolution kernel that is convolved 381 with the layer input to produce a tensor of 382 outputs. If `use_bias` is True, 383 a bias vector is created and added to the outputs. Finally, if 384 `activation` is not `None`, it is applied to the outputs as well. 385 386 When using this layer as the first layer in a model, 387 provide the keyword argument `input_shape` 388 (tuple of integers, does not include the sample axis), 389 e.g. 
`input_shape=(128, 128, 3)` for 128x128 RGB pictures 390 in `data_format="channels_last"`. 391 392 Arguments: 393 filters: Integer, the dimensionality of the output space 394 (i.e. the number of output filters in the convolution). 395 kernel_size: An integer or tuple/list of 2 integers, specifying the 396 height and width of the 2D convolution window. 397 Can be a single integer to specify the same value for 398 all spatial dimensions. 399 strides: An integer or tuple/list of 2 integers, 400 specifying the strides of the convolution along the height and width. 401 Can be a single integer to specify the same value for 402 all spatial dimensions. 403 Specifying any stride value != 1 is incompatible with specifying 404 any `dilation_rate` value != 1. 405 padding: one of `"valid"` or `"same"` (case-insensitive). 406 data_format: A string, 407 one of `channels_last` (default) or `channels_first`. 408 The ordering of the dimensions in the inputs. 409 `channels_last` corresponds to inputs with shape 410 `(batch, height, width, channels)` while `channels_first` 411 corresponds to inputs with shape 412 `(batch, channels, height, width)`. 413 It defaults to the `image_data_format` value found in your 414 Keras config file at `~/.keras/keras.json`. 415 If you never set it, then it will be "channels_last". 416 dilation_rate: an integer or tuple/list of 2 integers, specifying 417 the dilation rate to use for dilated convolution. 418 Can be a single integer to specify the same value for 419 all spatial dimensions. 420 Currently, specifying any `dilation_rate` value != 1 is 421 incompatible with specifying any stride value != 1. 422 activation: Activation function to use. 423 If you don't specify anything, no activation is applied 424 (ie. "linear" activation: `a(x) = x`). 425 use_bias: Boolean, whether the layer uses a bias vector. 426 kernel_initializer: Initializer for the `kernel` weights matrix. 427 bias_initializer: Initializer for the bias vector. 
428 kernel_regularizer: Regularizer function applied to 429 the `kernel` weights matrix. 430 bias_regularizer: Regularizer function applied to the bias vector. 431 activity_regularizer: Regularizer function applied to 432 the output of the layer (its "activation").. 433 kernel_constraint: Constraint function applied to the kernel matrix. 434 bias_constraint: Constraint function applied to the bias vector. 435 436 Input shape: 437 4D tensor with shape: 438 `(samples, channels, rows, cols)` if data_format='channels_first' 439 or 4D tensor with shape: 440 `(samples, rows, cols, channels)` if data_format='channels_last'. 441 442 Output shape: 443 4D tensor with shape: 444 `(samples, filters, new_rows, new_cols)` if data_format='channels_first' 445 or 4D tensor with shape: 446 `(samples, new_rows, new_cols, filters)` if data_format='channels_last'. 447 `rows` and `cols` values might have changed due to padding. 448 """ 449 450 def __init__(self, 451 filters, 452 kernel_size, 453 strides=(1, 1), 454 padding='valid', 455 data_format=None, 456 dilation_rate=(1, 1), 457 activation=None, 458 use_bias=True, 459 kernel_initializer='glorot_uniform', 460 bias_initializer='zeros', 461 kernel_regularizer=None, 462 bias_regularizer=None, 463 activity_regularizer=None, 464 kernel_constraint=None, 465 bias_constraint=None, 466 **kwargs): 467 super(Conv2D, self).__init__( 468 rank=2, 469 filters=filters, 470 kernel_size=kernel_size, 471 strides=strides, 472 padding=padding, 473 data_format=data_format, 474 dilation_rate=dilation_rate, 475 activation=activations.get(activation), 476 use_bias=use_bias, 477 kernel_initializer=initializers.get(kernel_initializer), 478 bias_initializer=initializers.get(bias_initializer), 479 kernel_regularizer=regularizers.get(kernel_regularizer), 480 bias_regularizer=regularizers.get(bias_regularizer), 481 activity_regularizer=regularizers.get(activity_regularizer), 482 kernel_constraint=constraints.get(kernel_constraint), 483 
bias_constraint=constraints.get(bias_constraint), 484 **kwargs) 485 486 487@keras_export('keras.layers.Conv3D', 'keras.layers.Convolution3D') 488class Conv3D(Conv): 489 """3D convolution layer (e.g. spatial convolution over volumes). 490 491 This layer creates a convolution kernel that is convolved 492 with the layer input to produce a tensor of 493 outputs. If `use_bias` is True, 494 a bias vector is created and added to the outputs. Finally, if 495 `activation` is not `None`, it is applied to the outputs as well. 496 497 When using this layer as the first layer in a model, 498 provide the keyword argument `input_shape` 499 (tuple of integers, does not include the sample axis), 500 e.g. `input_shape=(128, 128, 128, 1)` for 128x128x128 volumes 501 with a single channel, 502 in `data_format="channels_last"`. 503 504 Arguments: 505 filters: Integer, the dimensionality of the output space 506 (i.e. the number of output filters in the convolution). 507 kernel_size: An integer or tuple/list of 3 integers, specifying the 508 depth, height and width of the 3D convolution window. 509 Can be a single integer to specify the same value for 510 all spatial dimensions. 511 strides: An integer or tuple/list of 3 integers, 512 specifying the strides of the convolution along each spatial 513 dimension. 514 Can be a single integer to specify the same value for 515 all spatial dimensions. 516 Specifying any stride value != 1 is incompatible with specifying 517 any `dilation_rate` value != 1. 518 padding: one of `"valid"` or `"same"` (case-insensitive). 519 data_format: A string, 520 one of `channels_last` (default) or `channels_first`. 521 The ordering of the dimensions in the inputs. 522 `channels_last` corresponds to inputs with shape 523 `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 524 while `channels_first` corresponds to inputs with shape 525 `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 
526 It defaults to the `image_data_format` value found in your 527 Keras config file at `~/.keras/keras.json`. 528 If you never set it, then it will be "channels_last". 529 dilation_rate: an integer or tuple/list of 3 integers, specifying 530 the dilation rate to use for dilated convolution. 531 Can be a single integer to specify the same value for 532 all spatial dimensions. 533 Currently, specifying any `dilation_rate` value != 1 is 534 incompatible with specifying any stride value != 1. 535 activation: Activation function to use. 536 If you don't specify anything, no activation is applied 537 (ie. "linear" activation: `a(x) = x`). 538 use_bias: Boolean, whether the layer uses a bias vector. 539 kernel_initializer: Initializer for the `kernel` weights matrix. 540 bias_initializer: Initializer for the bias vector. 541 kernel_regularizer: Regularizer function applied to 542 the `kernel` weights matrix. 543 bias_regularizer: Regularizer function applied to the bias vector. 544 activity_regularizer: Regularizer function applied to 545 the output of the layer (its "activation").. 546 kernel_constraint: Constraint function applied to the kernel matrix. 547 bias_constraint: Constraint function applied to the bias vector. 548 549 Input shape: 550 5D tensor with shape: 551 `(samples, channels, conv_dim1, conv_dim2, conv_dim3)` if 552 data_format='channels_first' 553 or 5D tensor with shape: 554 `(samples, conv_dim1, conv_dim2, conv_dim3, channels)` if 555 data_format='channels_last'. 556 557 Output shape: 558 5D tensor with shape: 559 `(samples, filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` if 560 data_format='channels_first' 561 or 5D tensor with shape: 562 `(samples, new_conv_dim1, new_conv_dim2, new_conv_dim3, filters)` if 563 data_format='channels_last'. 564 `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have 565 changed due to padding. 
566 """ 567 568 def __init__(self, 569 filters, 570 kernel_size, 571 strides=(1, 1, 1), 572 padding='valid', 573 data_format=None, 574 dilation_rate=(1, 1, 1), 575 activation=None, 576 use_bias=True, 577 kernel_initializer='glorot_uniform', 578 bias_initializer='zeros', 579 kernel_regularizer=None, 580 bias_regularizer=None, 581 activity_regularizer=None, 582 kernel_constraint=None, 583 bias_constraint=None, 584 **kwargs): 585 super(Conv3D, self).__init__( 586 rank=3, 587 filters=filters, 588 kernel_size=kernel_size, 589 strides=strides, 590 padding=padding, 591 data_format=data_format, 592 dilation_rate=dilation_rate, 593 activation=activations.get(activation), 594 use_bias=use_bias, 595 kernel_initializer=initializers.get(kernel_initializer), 596 bias_initializer=initializers.get(bias_initializer), 597 kernel_regularizer=regularizers.get(kernel_regularizer), 598 bias_regularizer=regularizers.get(bias_regularizer), 599 activity_regularizer=regularizers.get(activity_regularizer), 600 kernel_constraint=constraints.get(kernel_constraint), 601 bias_constraint=constraints.get(bias_constraint), 602 **kwargs) 603 604 605@keras_export('keras.layers.Conv2DTranspose', 606 'keras.layers.Convolution2DTranspose') 607class Conv2DTranspose(Conv2D): 608 """Transposed convolution layer (sometimes called Deconvolution). 609 610 The need for transposed convolutions generally arises 611 from the desire to use a transformation going in the opposite direction 612 of a normal convolution, i.e., from something that has the shape of the 613 output of some convolution to something that has the shape of its input 614 while maintaining a connectivity pattern that is compatible with 615 said convolution. 616 617 When using this layer as the first layer in a model, 618 provide the keyword argument `input_shape` 619 (tuple of integers, does not include the sample axis), 620 e.g. `input_shape=(128, 128, 3)` for 128x128 RGB pictures 621 in `data_format="channels_last"`. 
622 623 Arguments: 624 filters: Integer, the dimensionality of the output space 625 (i.e. the number of output filters in the convolution). 626 kernel_size: An integer or tuple/list of 2 integers, specifying the 627 height and width of the 2D convolution window. 628 Can be a single integer to specify the same value for 629 all spatial dimensions. 630 strides: An integer or tuple/list of 2 integers, 631 specifying the strides of the convolution along the height and width. 632 Can be a single integer to specify the same value for 633 all spatial dimensions. 634 Specifying any stride value != 1 is incompatible with specifying 635 any `dilation_rate` value != 1. 636 padding: one of `"valid"` or `"same"` (case-insensitive). 637 output_padding: An integer or tuple/list of 2 integers, 638 specifying the amount of padding along the height and width 639 of the output tensor. 640 Can be a single integer to specify the same value for all 641 spatial dimensions. 642 The amount of output padding along a given dimension must be 643 lower than the stride along that same dimension. 644 If set to `None` (default), the output shape is inferred. 645 data_format: A string, 646 one of `channels_last` (default) or `channels_first`. 647 The ordering of the dimensions in the inputs. 648 `channels_last` corresponds to inputs with shape 649 `(batch, height, width, channels)` while `channels_first` 650 corresponds to inputs with shape 651 `(batch, channels, height, width)`. 652 It defaults to the `image_data_format` value found in your 653 Keras config file at `~/.keras/keras.json`. 654 If you never set it, then it will be "channels_last". 655 dilation_rate: an integer or tuple/list of 2 integers, specifying 656 the dilation rate to use for dilated convolution. 657 Can be a single integer to specify the same value for 658 all spatial dimensions. 659 Currently, specifying any `dilation_rate` value != 1 is 660 incompatible with specifying any stride value != 1. 
661 activation: Activation function to use. 662 If you don't specify anything, no activation is applied 663 (ie. "linear" activation: `a(x) = x`). 664 use_bias: Boolean, whether the layer uses a bias vector. 665 kernel_initializer: Initializer for the `kernel` weights matrix. 666 bias_initializer: Initializer for the bias vector. 667 kernel_regularizer: Regularizer function applied to 668 the `kernel` weights matrix. 669 bias_regularizer: Regularizer function applied to the bias vector. 670 activity_regularizer: Regularizer function applied to 671 the output of the layer (its "activation").. 672 kernel_constraint: Constraint function applied to the kernel matrix. 673 bias_constraint: Constraint function applied to the bias vector. 674 675 Input shape: 676 4D tensor with shape: 677 `(batch, channels, rows, cols)` if data_format='channels_first' 678 or 4D tensor with shape: 679 `(batch, rows, cols, channels)` if data_format='channels_last'. 680 681 Output shape: 682 4D tensor with shape: 683 `(batch, filters, new_rows, new_cols)` if data_format='channels_first' 684 or 4D tensor with shape: 685 `(batch, new_rows, new_cols, filters)` if data_format='channels_last'. 686 `rows` and `cols` values might have changed due to padding. 
687 688 References: 689 - [A guide to convolution arithmetic for deep 690 learning](https://arxiv.org/abs/1603.07285v1) 691 - [Deconvolutional 692 Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf) 693 """ 694 695 def __init__(self, 696 filters, 697 kernel_size, 698 strides=(1, 1), 699 padding='valid', 700 output_padding=None, 701 data_format=None, 702 dilation_rate=(1, 1), 703 activation=None, 704 use_bias=True, 705 kernel_initializer='glorot_uniform', 706 bias_initializer='zeros', 707 kernel_regularizer=None, 708 bias_regularizer=None, 709 activity_regularizer=None, 710 kernel_constraint=None, 711 bias_constraint=None, 712 **kwargs): 713 super(Conv2DTranspose, self).__init__( 714 filters=filters, 715 kernel_size=kernel_size, 716 strides=strides, 717 padding=padding, 718 data_format=data_format, 719 dilation_rate=dilation_rate, 720 activation=activations.get(activation), 721 use_bias=use_bias, 722 kernel_initializer=initializers.get(kernel_initializer), 723 bias_initializer=initializers.get(bias_initializer), 724 kernel_regularizer=regularizers.get(kernel_regularizer), 725 bias_regularizer=regularizers.get(bias_regularizer), 726 activity_regularizer=regularizers.get(activity_regularizer), 727 kernel_constraint=constraints.get(kernel_constraint), 728 bias_constraint=constraints.get(bias_constraint), 729 **kwargs) 730 731 self.output_padding = output_padding 732 if self.output_padding is not None: 733 self.output_padding = conv_utils.normalize_tuple( 734 self.output_padding, 2, 'output_padding') 735 for stride, out_pad in zip(self.strides, self.output_padding): 736 if out_pad >= stride: 737 raise ValueError('Stride ' + str(self.strides) + ' must be ' 738 'greater than output padding ' + 739 str(self.output_padding)) 740 741 def build(self, input_shape): 742 input_shape = tensor_shape.TensorShape(input_shape) 743 if len(input_shape) != 4: 744 raise ValueError('Inputs should have rank 4. 
Received input shape: ' + 745 str(input_shape)) 746 if self.data_format == 'channels_first': 747 channel_axis = 1 748 else: 749 channel_axis = -1 750 if input_shape.dims[channel_axis].value is None: 751 raise ValueError('The channel dimension of the inputs ' 752 'should be defined. Found `None`.') 753 input_dim = int(input_shape[channel_axis]) 754 self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim}) 755 kernel_shape = self.kernel_size + (self.filters, input_dim) 756 757 self.kernel = self.add_weight( 758 name='kernel', 759 shape=kernel_shape, 760 initializer=self.kernel_initializer, 761 regularizer=self.kernel_regularizer, 762 constraint=self.kernel_constraint, 763 trainable=True, 764 dtype=self.dtype) 765 if self.use_bias: 766 self.bias = self.add_weight( 767 name='bias', 768 shape=(self.filters,), 769 initializer=self.bias_initializer, 770 regularizer=self.bias_regularizer, 771 constraint=self.bias_constraint, 772 trainable=True, 773 dtype=self.dtype) 774 else: 775 self.bias = None 776 self.built = True 777 778 def call(self, inputs): 779 inputs_shape = array_ops.shape(inputs) 780 batch_size = inputs_shape[0] 781 if self.data_format == 'channels_first': 782 h_axis, w_axis = 2, 3 783 else: 784 h_axis, w_axis = 1, 2 785 786 height, width = inputs_shape[h_axis], inputs_shape[w_axis] 787 kernel_h, kernel_w = self.kernel_size 788 stride_h, stride_w = self.strides 789 790 if self.output_padding is None: 791 out_pad_h = out_pad_w = None 792 else: 793 out_pad_h, out_pad_w = self.output_padding 794 795 # Infer the dynamic output shape: 796 out_height = conv_utils.deconv_output_length(height, 797 kernel_h, 798 padding=self.padding, 799 output_padding=out_pad_h, 800 stride=stride_h, 801 dilation=self.dilation_rate[0]) 802 out_width = conv_utils.deconv_output_length(width, 803 kernel_w, 804 padding=self.padding, 805 output_padding=out_pad_w, 806 stride=stride_w, 807 dilation=self.dilation_rate[1]) 808 if self.data_format == 'channels_first': 809 output_shape = 
def compute_output_shape(self, input_shape):
  """Computes the static output shape for a given input shape.

  Arguments:
    input_shape: Anything accepted by `tensor_shape.TensorShape`.

  Returns:
    A `TensorShape` equal to the input shape except that the channel entry
    is `self.filters` and the height/width entries are replaced by the
    values returned from `conv_utils.deconv_output_length`.
  """
  dims = list(tensor_shape.TensorShape(input_shape).as_list())
  if self.data_format == 'channels_first':
    channel_idx, rows_idx, cols_idx = 1, 2, 3
  else:
    channel_idx, rows_idx, cols_idx = 3, 1, 2

  kernel_rows, kernel_cols = self.kernel_size
  stride_rows, stride_cols = self.strides
  pad_rows, pad_cols = (
      (None, None) if self.output_padding is None else self.output_padding)

  dims[channel_idx] = self.filters
  dims[rows_idx] = conv_utils.deconv_output_length(
      dims[rows_idx],
      kernel_rows,
      padding=self.padding,
      output_padding=pad_rows,
      stride=stride_rows,
      dilation=self.dilation_rate[0])
  dims[cols_idx] = conv_utils.deconv_output_length(
      dims[cols_idx],
      kernel_cols,
      padding=self.padding,
      output_padding=pad_cols,
      stride=stride_cols,
      dilation=self.dilation_rate[1])
  return tensor_shape.TensorShape(dims)

def get_config(self):
  """Returns the parent config extended with `output_padding`."""
  layer_config = super(Conv2DTranspose, self).get_config()
  layer_config.update({'output_padding': self.output_padding})
  return layer_config
@keras_export('keras.layers.Conv3DTranspose',
              'keras.layers.Convolution3DTranspose')
class Conv3DTranspose(Conv3D):
  """Transposed convolution layer (sometimes called Deconvolution).

  The need for transposed convolutions generally arises
  from the desire to use a transformation going in the opposite direction
  of a normal convolution, i.e., from something that has the shape of the
  output of some convolution to something that has the shape of its input
  while maintaining a connectivity pattern that is compatible with
  said convolution.

  When using this layer as the first layer in a model,
  provide the keyword argument `input_shape`
  (tuple of integers, does not include the sample axis),
  e.g. `input_shape=(128, 128, 128, 3)` for a 128x128x128 volume with 3 channels
  if `data_format="channels_last"`.

  Arguments:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of 3 integers, specifying the
      depth, height and width of the 3D convolution window.
      Can be a single integer to specify the same value for
      all spatial dimensions.
    strides: An integer or tuple/list of 3 integers,
      specifying the strides of the convolution along the depth, height
        and width.
      Can be a single integer to specify the same value for
      all spatial dimensions.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: one of `"valid"` or `"same"` (case-insensitive).
    output_padding: An integer or tuple/list of 3 integers,
      specifying the amount of padding along the depth, height, and
      width.
      Can be a single integer to specify the same value for all
      spatial dimensions.
      The amount of output padding along a given dimension must be
      lower than the stride along that same dimension.
      If set to `None` (default), the output shape is inferred.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, depth, height, width, channels)` while `channels_first`
      corresponds to inputs with shape
      `(batch, channels, depth, height, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: an integer or tuple/list of 3 integers, specifying
      the dilation rate to use for dilated convolution.
      Can be a single integer to specify the same value for
      all spatial dimensions.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any stride value != 1.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (ie. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to the kernel matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Input shape:
    5D tensor with shape:
    `(batch, channels, depth, rows, cols)` if data_format='channels_first'
    or 5D tensor with shape:
    `(batch, depth, rows, cols, channels)` if data_format='channels_last'.

  Output shape:
    5D tensor with shape:
    `(batch, filters, new_depth, new_rows, new_cols)` if
      data_format='channels_first'
    or 5D tensor with shape:
    `(batch, new_depth, new_rows, new_cols, filters)` if
      data_format='channels_last'.
    `depth` and `rows` and `cols` values might have changed due to padding.

  Raises:
    ValueError: At construction time if any `output_padding` entry is not
      strictly smaller than the corresponding stride; at build time if the
      input is not rank 5 or its channel dimension is undefined.

  References:
    - [A guide to convolution arithmetic for deep
      learning](https://arxiv.org/abs/1603.07285v1)
    - [Deconvolutional
      Networks](https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1, 1),
               padding='valid',
               output_padding=None,
               data_format=None,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(Conv3DTranspose, self).__init__(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        activation=activations.get(activation),
        use_bias=use_bias,
        kernel_initializer=initializers.get(kernel_initializer),
        bias_initializer=initializers.get(bias_initializer),
        kernel_regularizer=regularizers.get(kernel_regularizer),
        bias_regularizer=regularizers.get(bias_regularizer),
        activity_regularizer=regularizers.get(activity_regularizer),
        kernel_constraint=constraints.get(kernel_constraint),
        bias_constraint=constraints.get(bias_constraint),
        **kwargs)

    self.output_padding = output_padding
    if self.output_padding is not None:
      self.output_padding = conv_utils.normalize_tuple(
          self.output_padding, 3, 'output_padding')
      # Each output padding must be strictly smaller than its stride.
      for stride, out_pad in zip(self.strides, self.output_padding):
        if out_pad >= stride:
          raise ValueError('Stride ' + str(self.strides) + ' must be '
                           'greater than output padding ' +
                           str(self.output_padding))

  def build(self, input_shape):
    """Validates the input shape and creates the kernel and bias weights.

    Arguments:
      input_shape: Shape of the layer input; must have rank 5 and a defined
        channel dimension.

    Raises:
      ValueError: If the input is not rank 5 or the channel dimension is
        `None`.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    if len(input_shape) != 5:
      # Build one message string; passing the shape as a second positional
      # argument would make the exception render as a tuple.
      raise ValueError('Inputs should have rank 5, received input shape: ' +
                       str(input_shape))
    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    if input_shape.dims[channel_axis].value is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined, found None: ' + str(input_shape))
    input_dim = int(input_shape[channel_axis])
    # The kernel is laid out as spatial dims + (filters, input channels).
    kernel_shape = self.kernel_size + (self.filters, input_dim)
    self.input_spec = InputSpec(ndim=5, axes={channel_axis: input_dim})

    self.kernel = self.add_weight(
        'kernel',
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        trainable=True,
        dtype=self.dtype)
    if self.use_bias:
      self.bias = self.add_weight(
          'bias',
          shape=(self.filters,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    """Runs the transposed 3D convolution, then optional bias and activation.

    Arguments:
      inputs: Input tensor. Axes 2, 3, 4 are read as depth/height/width when
        `data_format` is `'channels_first'`, axes 1, 2, 3 otherwise.

    Returns:
      The output of `nn.conv3d_transpose`, with the bias added when
      `use_bias` is set and `activation` applied when it is not `None`.
    """
    inputs_shape = array_ops.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == 'channels_first':
      d_axis, h_axis, w_axis = 2, 3, 4
    else:
      d_axis, h_axis, w_axis = 1, 2, 3

    depth = inputs_shape[d_axis]
    height = inputs_shape[h_axis]
    width = inputs_shape[w_axis]

    kernel_d, kernel_h, kernel_w = self.kernel_size
    stride_d, stride_h, stride_w = self.strides

    if self.output_padding is None:
      out_pad_d = out_pad_h = out_pad_w = None
    else:
      out_pad_d, out_pad_h, out_pad_w = self.output_padding

    # Infer the dynamic output shape:
    out_depth = conv_utils.deconv_output_length(depth,
                                                kernel_d,
                                                padding=self.padding,
                                                output_padding=out_pad_d,
                                                stride=stride_d)
    out_height = conv_utils.deconv_output_length(height,
                                                 kernel_h,
                                                 padding=self.padding,
                                                 output_padding=out_pad_h,
                                                 stride=stride_h)
    out_width = conv_utils.deconv_output_length(width,
                                                kernel_w,
                                                padding=self.padding,
                                                output_padding=out_pad_w,
                                                stride=stride_w)
    if self.data_format == 'channels_first':
      output_shape = (batch_size, self.filters, out_depth, out_height,
                      out_width)
      strides = (1, 1, stride_d, stride_h, stride_w)
    else:
      output_shape = (batch_size, out_depth, out_height, out_width,
                      self.filters)
      strides = (1, stride_d, stride_h, stride_w, 1)

    output_shape_tensor = array_ops.stack(output_shape)
    outputs = nn.conv3d_transpose(
        inputs,
        self.kernel,
        output_shape_tensor,
        strides,
        data_format=conv_utils.convert_data_format(self.data_format, ndim=5),
        padding=self.padding.upper())

    if not context.executing_eagerly():
      # Infer the static output shape:
      out_shape = self.compute_output_shape(inputs.shape)
      outputs.set_shape(out_shape)

    if self.use_bias:
      # NOTE(review): the format string is derived with ndim=4 although the
      # output is 5D; presumably `bias_add` only needs to know where the
      # channel axis sits -- confirm before changing.
      outputs = nn.bias_add(
          outputs,
          self.bias,
          data_format=conv_utils.convert_data_format(self.data_format, ndim=4))

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

  def compute_output_shape(self, input_shape):
    """Computes the static output shape for a given input shape.

    Arguments:
      input_shape: Anything accepted by `tensor_shape.TensorShape`.

    Returns:
      A `TensorShape` equal to the input shape except that the channel entry
      is `self.filters` and the depth/height/width entries are replaced by
      the values returned from `conv_utils.deconv_output_length`.
    """
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    output_shape = list(input_shape)
    if self.data_format == 'channels_first':
      c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
    else:
      c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3

    kernel_d, kernel_h, kernel_w = self.kernel_size
    stride_d, stride_h, stride_w = self.strides

    if self.output_padding is None:
      out_pad_d = out_pad_h = out_pad_w = None
    else:
      out_pad_d, out_pad_h, out_pad_w = self.output_padding

    output_shape[c_axis] = self.filters
    output_shape[d_axis] = conv_utils.deconv_output_length(
        output_shape[d_axis],
        kernel_d,
        padding=self.padding,
        output_padding=out_pad_d,
        stride=stride_d)
    output_shape[h_axis] = conv_utils.deconv_output_length(
        output_shape[h_axis],
        kernel_h,
        padding=self.padding,
        output_padding=out_pad_h,
        stride=stride_h)
    output_shape[w_axis] = conv_utils.deconv_output_length(
        output_shape[w_axis],
        kernel_w,
        padding=self.padding,
        output_padding=out_pad_w,
        stride=stride_w)
    return tensor_shape.TensorShape(output_shape)

  def get_config(self):
    """Returns the parent config without `dilation_rate`, plus `output_padding`.

    `dilation_rate` is removed because this layer's `__init__` has no such
    parameter.
    """
    config = super(Conv3DTranspose, self).get_config()
    config.pop('dilation_rate')
    config['output_padding'] = self.output_padding
    return config
class SeparableConv(Conv):
  """Abstract base layer for separable nD convolution.

  A depthwise convolution is applied to every input channel independently and
  is followed by a pointwise convolution that mixes the resulting channels.
  When `use_bias` is True and a bias initializer is provided, a bias vector is
  added to the output, and the optional activation is applied last.

  Arguments:
    rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: A tuple or list of integers giving the spatial dimensions of
      the filters, or a single integer used for all spatial dimensions.
    strides: A tuple or list of integers giving the strides of the
      convolution, or a single integer used for all spatial dimensions.
      Any `stride` value != 1 is incompatible with any `dilation_rate`
      value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or
      `channels_first`. `channels_last` corresponds to inputs with shape
      `(batch, ..., channels)`, `channels_first` to inputs with shape
      `(batch, channels, ...)`.
    dilation_rate: An integer or tuple/list of integers, the dilation rate to
      use for dilated convolution. A single integer applies to all spatial
      dimensions. Any `dilation_rate` value != 1 is currently incompatible
      with any stride value != 1.
    depth_multiplier: The number of depthwise convolution output channels for
      each input channel. The depthwise step produces
      `num_filters_in * depth_multiplier` channels in total.
    activation: Activation function. Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    depthwise_initializer: An initializer for the depthwise convolution kernel.
    pointwise_initializer: An initializer for the pointwise convolution kernel.
    bias_initializer: An initializer for the bias vector. If None, the default
      initializer will be used.
    depthwise_regularizer: Optional regularizer for the depthwise
      convolution kernel.
    pointwise_regularizer: Optional regularizer for the pointwise
      convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Optional regularizer function for the output.
    depthwise_constraint: Optional projection function applied to the
      depthwise kernel after it has been updated by an `Optimizer` (e.g. for
      norm or value constraints on layer weights). It receives the
      unprojected variable and must return the projected variable with the
      same shape. Constraints are not safe to use when doing asynchronous
      distributed training.
    pointwise_constraint: Optional projection function applied to the
      pointwise kernel after it has been updated by an `Optimizer`.
    bias_constraint: Optional projection function applied to the
      bias after it has been updated by an `Optimizer`.
    trainable: Boolean, if `True` the weights of this layer will be marked as
      trainable (and listed in `layer.trainable_weights`).
    name: A string, the name of the layer.
  """

  def __init__(self,
               rank,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format=None,
               dilation_rate=1,
               depth_multiplier=1,
               activation=None,
               use_bias=True,
               depthwise_initializer='glorot_uniform',
               pointwise_initializer='glorot_uniform',
               bias_initializer='zeros',
               depthwise_regularizer=None,
               pointwise_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               depthwise_constraint=None,
               pointwise_constraint=None,
               bias_constraint=None,
               trainable=True,
               name=None,
               **kwargs):
    # Settings shared with the plain convolution go to the parent class; no
    # kernel_* arguments are forwarded, since this layer keeps separate
    # depthwise and pointwise settings instead.
    super(SeparableConv, self).__init__(
        rank=rank, filters=filters, kernel_size=kernel_size, strides=strides,
        padding=padding, data_format=data_format,
        dilation_rate=dilation_rate,
        activation=activations.get(activation),
        use_bias=use_bias,
        bias_initializer=initializers.get(bias_initializer),
        bias_regularizer=regularizers.get(bias_regularizer),
        activity_regularizer=regularizers.get(activity_regularizer),
        bias_constraint=bias_constraint,
        trainable=trainable, name=name,
        **kwargs)

    self.depth_multiplier = depth_multiplier
    # Resolve the user-supplied identifiers into objects.
    self.depthwise_initializer = initializers.get(depthwise_initializer)
    self.pointwise_initializer = initializers.get(pointwise_initializer)
    self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
    self.pointwise_regularizer = regularizers.get(pointwise_regularizer)
    self.depthwise_constraint = constraints.get(depthwise_constraint)
    self.pointwise_constraint = constraints.get(pointwise_constraint)

  def build(self, input_shape):
    """Creates the depthwise kernel, the pointwise kernel and the bias.

    Arguments:
      input_shape: Shape of the layer input; its channel dimension must be
        defined.

    Raises:
      ValueError: If the channel dimension of `input_shape` is `None`.
    """
    shape = tensor_shape.TensorShape(input_shape)
    channel_axis = 1 if self.data_format == 'channels_first' else -1
    if shape.dims[channel_axis].value is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    in_channels = int(shape[channel_axis])
    self.input_spec = InputSpec(
        ndim=self.rank + 2, axes={channel_axis: in_channels})

    # Depthwise kernel: spatial dims + (input channels, depth multiplier).
    depthwise_shape = self.kernel_size + (in_channels, self.depth_multiplier)
    # Pointwise kernel: size 1 in every spatial dim, mapping the depthwise
    # output channels to `filters`.
    pointwise_shape = (1,) * self.rank + (
        self.depth_multiplier * in_channels, self.filters)

    self.depthwise_kernel = self.add_weight(
        name='depthwise_kernel',
        shape=depthwise_shape,
        initializer=self.depthwise_initializer,
        regularizer=self.depthwise_regularizer,
        constraint=self.depthwise_constraint,
        trainable=True,
        dtype=self.dtype)
    self.pointwise_kernel = self.add_weight(
        name='pointwise_kernel',
        shape=pointwise_shape,
        initializer=self.pointwise_initializer,
        regularizer=self.pointwise_regularizer,
        constraint=self.pointwise_constraint,
        trainable=True,
        dtype=self.dtype)
    if not self.use_bias:
      self.bias = None
    else:
      self.bias = self.add_weight(
          name='bias',
          shape=(self.filters,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          trainable=True,
          dtype=self.dtype)
    self.built = True

  def call(self, inputs):
    """Not implemented on the abstract base; always raises."""
    raise NotImplementedError

  def get_config(self):
    """Returns the parent config overlaid with this layer's own settings."""
    own_config = {
        'filters': self.filters,
        'kernel_size': self.kernel_size,
        'strides': self.strides,
        'padding': self.padding,
        'data_format': self.data_format,
        'depth_multiplier': self.depth_multiplier,
        'dilation_rate': self.dilation_rate,
        'activation': activations.serialize(self.activation),
        'use_bias': self.use_bias,
        'depthwise_initializer':
            initializers.serialize(self.depthwise_initializer),
        'pointwise_initializer':
            initializers.serialize(self.pointwise_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'depthwise_regularizer':
            regularizers.serialize(self.depthwise_regularizer),
        'pointwise_regularizer':
            regularizers.serialize(self.pointwise_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'depthwise_constraint':
            constraints.serialize(self.depthwise_constraint),
        'pointwise_constraint':
            constraints.serialize(self.pointwise_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint),
    }
    # Parent entries come first; entries defined here win on key clashes.
    merged = dict(super(SeparableConv, self).get_config())
    merged.update(own_config)
    return merged
@keras_export('keras.layers.SeparableConv1D',
              'keras.layers.SeparableConvolution1D')
class SeparableConv1D(SeparableConv):
  """Depthwise separable 1D convolution.

  A depthwise convolution is applied to every input channel independently and
  is followed by a pointwise convolution that mixes the resulting channels.
  When `use_bias` is True and a bias initializer is provided, a bias vector is
  added to the output, and the optional activation is applied last.

  Arguments:
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: A single integer specifying the spatial
      dimensions of the filters.
    strides: A single integer specifying the strides of the convolution.
      Any `stride` value != 1 is incompatible with any `dilation_rate`
      value != 1.
    padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or
      `channels_first`. `channels_last` corresponds to inputs with shape
      `(batch, length, channels)`, `channels_first` to inputs with shape
      `(batch, channels, length)`.
    dilation_rate: A single integer, the dilation rate to use for dilated
      convolution. Any `dilation_rate` value != 1 is currently incompatible
      with any stride value != 1.
    depth_multiplier: The number of depthwise convolution output channels for
      each input channel. The depthwise step produces
      `num_filters_in * depth_multiplier` channels in total.
    activation: Activation function. Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    depthwise_initializer: An initializer for the depthwise convolution kernel.
    pointwise_initializer: An initializer for the pointwise convolution kernel.
    bias_initializer: An initializer for the bias vector. If None, the default
      initializer will be used.
    depthwise_regularizer: Optional regularizer for the depthwise
      convolution kernel.
    pointwise_regularizer: Optional regularizer for the pointwise
      convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Optional regularizer function for the output.
    depthwise_constraint: Optional projection function applied to the
      depthwise kernel after it has been updated by an `Optimizer` (e.g. for
      norm or value constraints on layer weights). It receives the
      unprojected variable and must return the projected variable with the
      same shape. Constraints are not safe to use when doing asynchronous
      distributed training.
    pointwise_constraint: Optional projection function applied to the
      pointwise kernel after it has been updated by an `Optimizer`.
    bias_constraint: Optional projection function applied to the
      bias after it has been updated by an `Optimizer`.
    trainable: Boolean, if `True` the weights of this layer will be marked as
      trainable (and listed in `layer.trainable_weights`).
    name: A string, the name of the layer.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format=None,
               dilation_rate=1,
               depth_multiplier=1,
               activation=None,
               use_bias=True,
               depthwise_initializer='glorot_uniform',
               pointwise_initializer='glorot_uniform',
               bias_initializer='zeros',
               depthwise_regularizer=None,
               pointwise_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               depthwise_constraint=None,
               pointwise_constraint=None,
               bias_constraint=None,
               **kwargs):
    # Fix the rank to 1 and hand resolved objects to the separable base.
    super(SeparableConv1D, self).__init__(
        rank=1, filters=filters, kernel_size=kernel_size, strides=strides,
        padding=padding, data_format=data_format,
        dilation_rate=dilation_rate, depth_multiplier=depth_multiplier,
        activation=activations.get(activation),
        use_bias=use_bias,
        depthwise_initializer=initializers.get(depthwise_initializer),
        pointwise_initializer=initializers.get(pointwise_initializer),
        bias_initializer=initializers.get(bias_initializer),
        depthwise_regularizer=regularizers.get(depthwise_regularizer),
        pointwise_regularizer=regularizers.get(pointwise_regularizer),
        bias_regularizer=regularizers.get(bias_regularizer),
        activity_regularizer=regularizers.get(activity_regularizer),
        depthwise_constraint=constraints.get(depthwise_constraint),
        pointwise_constraint=constraints.get(pointwise_constraint),
        bias_constraint=constraints.get(bias_constraint),
        **kwargs)

  def call(self, inputs):
    """Applies the separable 1D convolution through the 2D op.

    The input and both kernels get an extra size-1 axis so that
    `nn.separable_conv2d` can be used; that axis is squeezed out of the
    result again.

    Arguments:
      inputs: Input tensor of the layer.

    Returns:
      The convolved tensor, with the bias added when `use_bias` is set and
      `activation` applied when it is not `None`.
    """
    is_causal = self.padding == 'causal'
    if is_causal:
      inputs = array_ops.pad(inputs, self._compute_causal_padding())

    if self.data_format == 'channels_last':
      spatial_start_dim = 1
      strides = (1,) + self.strides * 2 + (1,)
    else:
      spatial_start_dim = 2
      strides = (1, 1) + self.strides * 2

    # Explicitly broadcast inputs and kernels to 4D.
    # TODO(fchollet): refactor when a native separable_conv1d op is available.
    inputs = array_ops.expand_dims(inputs, spatial_start_dim)
    depthwise_kernel_4d = array_ops.expand_dims(self.depthwise_kernel, 0)
    pointwise_kernel_4d = array_ops.expand_dims(self.pointwise_kernel, 0)
    rate = (1,) + self.dilation_rate

    # The input was already padded above for 'causal', so the op itself then
    # runs with 'valid' padding.
    op_padding = 'valid' if is_causal else self.padding
    outputs = nn.separable_conv2d(
        inputs,
        depthwise_kernel_4d,
        pointwise_kernel_4d,
        strides=strides,
        padding=op_padding.upper(),
        rate=rate,
        data_format=conv_utils.convert_data_format(self.data_format, ndim=4))

    if self.use_bias:
      outputs = nn.bias_add(
          outputs,
          self.bias,
          data_format=conv_utils.convert_data_format(self.data_format, ndim=4))

    # Drop the axis that was inserted for the 2D op.
    outputs = array_ops.squeeze(outputs, [spatial_start_dim])

    if self.activation is None:
      return outputs
    return self.activation(outputs)
SeparableConv2D(SeparableConv): 1522 """Depthwise separable 2D convolution. 1523 1524 Separable convolutions consist in first performing 1525 a depthwise spatial convolution 1526 (which acts on each input channel separately) 1527 followed by a pointwise convolution which mixes together the resulting 1528 output channels. The `depth_multiplier` argument controls how many 1529 output channels are generated per input channel in the depthwise step. 1530 1531 Intuitively, separable convolutions can be understood as 1532 a way to factorize a convolution kernel into two smaller kernels, 1533 or as an extreme version of an Inception block. 1534 1535 Arguments: 1536 filters: Integer, the dimensionality of the output space 1537 (i.e. the number of output filters in the convolution). 1538 kernel_size: An integer or tuple/list of 2 integers, specifying the 1539 height and width of the 2D convolution window. 1540 Can be a single integer to specify the same value for 1541 all spatial dimensions. 1542 strides: An integer or tuple/list of 2 integers, 1543 specifying the strides of the convolution along the height and width. 1544 Can be a single integer to specify the same value for 1545 all spatial dimensions. 1546 Specifying any stride value != 1 is incompatible with specifying 1547 any `dilation_rate` value != 1. 1548 padding: one of `"valid"` or `"same"` (case-insensitive). 1549 data_format: A string, 1550 one of `channels_last` (default) or `channels_first`. 1551 The ordering of the dimensions in the inputs. 1552 `channels_last` corresponds to inputs with shape 1553 `(batch, height, width, channels)` while `channels_first` 1554 corresponds to inputs with shape 1555 `(batch, channels, height, width)`. 1556 It defaults to the `image_data_format` value found in your 1557 Keras config file at `~/.keras/keras.json`. 1558 If you never set it, then it will be "channels_last". 
1559 dilation_rate: An integer or tuple/list of 2 integers, specifying 1560 the dilation rate to use for dilated convolution. 1561 Currently, specifying any `dilation_rate` value != 1 is 1562 incompatible with specifying any `strides` value != 1. 1563 depth_multiplier: The number of depthwise convolution output channels 1564 for each input channel. 1565 The total number of depthwise convolution output 1566 channels will be equal to `filters_in * depth_multiplier`. 1567 activation: Activation function to use. 1568 If you don't specify anything, no activation is applied 1569 (ie. "linear" activation: `a(x) = x`). 1570 use_bias: Boolean, whether the layer uses a bias vector. 1571 depthwise_initializer: Initializer for the depthwise kernel matrix. 1572 pointwise_initializer: Initializer for the pointwise kernel matrix. 1573 bias_initializer: Initializer for the bias vector. 1574 depthwise_regularizer: Regularizer function applied to 1575 the depthwise kernel matrix. 1576 pointwise_regularizer: Regularizer function applied to 1577 the pointwise kernel matrix. 1578 bias_regularizer: Regularizer function applied to the bias vector. 1579 activity_regularizer: Regularizer function applied to 1580 the output of the layer (its "activation").. 1581 depthwise_constraint: Constraint function applied to 1582 the depthwise kernel matrix. 1583 pointwise_constraint: Constraint function applied to 1584 the pointwise kernel matrix. 1585 bias_constraint: Constraint function applied to the bias vector. 1586 1587 Input shape: 1588 4D tensor with shape: 1589 `(batch, channels, rows, cols)` if data_format='channels_first' 1590 or 4D tensor with shape: 1591 `(batch, rows, cols, channels)` if data_format='channels_last'. 1592 1593 Output shape: 1594 4D tensor with shape: 1595 `(batch, filters, new_rows, new_cols)` if data_format='channels_first' 1596 or 4D tensor with shape: 1597 `(batch, new_rows, new_cols, filters)` if data_format='channels_last'. 
1598 `rows` and `cols` values might have changed due to padding. 1599 """ 1600 1601 def __init__(self, 1602 filters, 1603 kernel_size, 1604 strides=(1, 1), 1605 padding='valid', 1606 data_format=None, 1607 dilation_rate=(1, 1), 1608 depth_multiplier=1, 1609 activation=None, 1610 use_bias=True, 1611 depthwise_initializer='glorot_uniform', 1612 pointwise_initializer='glorot_uniform', 1613 bias_initializer='zeros', 1614 depthwise_regularizer=None, 1615 pointwise_regularizer=None, 1616 bias_regularizer=None, 1617 activity_regularizer=None, 1618 depthwise_constraint=None, 1619 pointwise_constraint=None, 1620 bias_constraint=None, 1621 **kwargs): 1622 super(SeparableConv2D, self).__init__( 1623 rank=2, 1624 filters=filters, 1625 kernel_size=kernel_size, 1626 strides=strides, 1627 padding=padding, 1628 data_format=data_format, 1629 dilation_rate=dilation_rate, 1630 depth_multiplier=depth_multiplier, 1631 activation=activations.get(activation), 1632 use_bias=use_bias, 1633 depthwise_initializer=initializers.get(depthwise_initializer), 1634 pointwise_initializer=initializers.get(pointwise_initializer), 1635 bias_initializer=initializers.get(bias_initializer), 1636 depthwise_regularizer=regularizers.get(depthwise_regularizer), 1637 pointwise_regularizer=regularizers.get(pointwise_regularizer), 1638 bias_regularizer=regularizers.get(bias_regularizer), 1639 activity_regularizer=regularizers.get(activity_regularizer), 1640 depthwise_constraint=constraints.get(depthwise_constraint), 1641 pointwise_constraint=constraints.get(pointwise_constraint), 1642 bias_constraint=constraints.get(bias_constraint), 1643 **kwargs) 1644 1645 def call(self, inputs): 1646 # Apply the actual ops. 
@keras_export('keras.layers.DepthwiseConv2D')
class DepthwiseConv2D(Conv2D):
  """Depthwise 2D convolution.

  A depthwise convolution performs just the first step of a depthwise
  separable convolution: a spatial convolution that acts on each input
  channel separately, with no pointwise step mixing the channels afterwards.
  The `depth_multiplier` argument controls how many
  output channels are generated per input channel in the depthwise step.

  Arguments:
    kernel_size: An integer or tuple/list of 2 integers, specifying the
      height and width of the 2D convolution window.
      Can be a single integer to specify the same value for
      all spatial dimensions.
    strides: An integer or tuple/list of 2 integers,
      specifying the strides of the convolution along the height and width.
      Can be a single integer to specify the same value for
      all spatial dimensions.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: one of `'valid'` or `'same'` (case-insensitive).
    depth_multiplier: The number of depthwise convolution output channels
      for each input channel.
      The total number of depthwise convolution output
      channels will be equal to `filters_in * depth_multiplier`.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, height, width, channels)` while `channels_first`
      corresponds to inputs with shape
      `(batch, channels, height, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be 'channels_last'.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (ie. 'linear' activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    depthwise_initializer: Initializer for the depthwise kernel matrix.
    bias_initializer: Initializer for the bias vector.
    depthwise_regularizer: Regularizer function applied to
      the depthwise kernel matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its 'activation').
    depthwise_constraint: Constraint function applied to
      the depthwise kernel matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Input shape:
    4D tensor with shape:
    `[batch, channels, rows, cols]` if data_format='channels_first'
    or 4D tensor with shape:
    `[batch, rows, cols, channels]` if data_format='channels_last'.

  Output shape:
    4D tensor with shape:
    `[batch, channels * depth_multiplier, new_rows, new_cols]` if
    data_format='channels_first'
    or 4D tensor with shape:
    `[batch, new_rows, new_cols, channels * depth_multiplier]` if
    data_format='channels_last'.
    `rows` and `cols` values might have changed due to padding.
  """

  def __init__(self,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               depth_multiplier=1,
               data_format=None,
               activation=None,
               use_bias=True,
               depthwise_initializer='glorot_uniform',
               bias_initializer='zeros',
               depthwise_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               depthwise_constraint=None,
               bias_constraint=None,
               **kwargs):
    # `filters` is meaningless for a depthwise convolution: the number of
    # output channels is derived from the input channels in `build`.
    super(DepthwiseConv2D, self).__init__(
        filters=None,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        activation=activation,
        use_bias=use_bias,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        bias_constraint=bias_constraint,
        **kwargs)
    self.depth_multiplier = depth_multiplier
    self.depthwise_initializer = initializers.get(depthwise_initializer)
    self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
    self.depthwise_constraint = constraints.get(depthwise_constraint)
    # Not forwarded to the parent constructor above, so it is set here.
    self.bias_initializer = initializers.get(bias_initializer)

  def build(self, input_shape):
    """Creates the depthwise kernel (and bias) for the given input shape.

    Arguments:
      input_shape: Shape of the 4D input; the channel dimension must be
        statically known.

    Raises:
      ValueError: If the input has fewer than 4 dimensions or its channel
        dimension is undefined.
    """
    # Accept tuples/lists as well as `TensorShape` instances; `.dims` below
    # is only available on the latter.
    input_shape = tensor_shape.TensorShape(input_shape)
    if len(input_shape) < 4:
      # Build a single message string; passing the shape as a second
      # positional argument would turn the exception text into a tuple.
      raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
                       'Received input shape: ' + str(input_shape))
    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = 3
    if input_shape.dims[channel_axis].value is None:
      raise ValueError('The channel dimension of the inputs to '
                       '`DepthwiseConv2D` '
                       'should be defined. Found `None`.')
    input_dim = int(input_shape[channel_axis])
    depthwise_kernel_shape = (self.kernel_size[0],
                              self.kernel_size[1],
                              input_dim,
                              self.depth_multiplier)

    self.depthwise_kernel = self.add_weight(
        shape=depthwise_kernel_shape,
        initializer=self.depthwise_initializer,
        name='depthwise_kernel',
        regularizer=self.depthwise_regularizer,
        constraint=self.depthwise_constraint)

    if self.use_bias:
      # One bias value per output channel.
      self.bias = self.add_weight(shape=(input_dim * self.depth_multiplier,),
                                  initializer=self.bias_initializer,
                                  name='bias',
                                  regularizer=self.bias_regularizer,
                                  constraint=self.bias_constraint)
    else:
      self.bias = None
    # Set input spec.
    self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
    self.built = True

  def call(self, inputs):
    """Applies the depthwise convolution, then optional bias and activation."""
    outputs = backend.depthwise_conv2d(
        inputs,
        self.depthwise_kernel,
        strides=self.strides,
        padding=self.padding,
        dilation_rate=self.dilation_rate,
        data_format=self.data_format)

    if self.use_bias:
      outputs = backend.bias_add(
          outputs,
          self.bias,
          data_format=self.data_format)

    if self.activation is not None:
      return self.activation(outputs)

    return outputs

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns the output shape as a tuple for the given input shape."""
    if self.data_format == 'channels_first':
      rows = input_shape[2]
      cols = input_shape[3]
      out_filters = input_shape[1] * self.depth_multiplier
    else:
      rows = input_shape[1]
      cols = input_shape[2]
      out_filters = input_shape[3] * self.depth_multiplier

    # `call` forwards `self.dilation_rate` to the backend op, so the dilated
    # kernel extent has to be accounted for here as well; otherwise the
    # reported shape is too large for dilated 'valid' convolutions.
    rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
                                         self.padding,
                                         self.strides[0],
                                         self.dilation_rate[0])
    cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
                                         self.padding,
                                         self.strides[1],
                                         self.dilation_rate[1])
    if self.data_format == 'channels_first':
      return (input_shape[0], out_filters, rows, cols)
    return (input_shape[0], rows, cols, out_filters)

  def get_config(self):
    """Returns the layer config, with kernel_* entries replaced by depthwise_*."""
    config = super(DepthwiseConv2D, self).get_config()
    # These parent entries are not constructor arguments of this layer.
    config.pop('filters')
    config.pop('kernel_initializer')
    config.pop('kernel_regularizer')
    config.pop('kernel_constraint')
    config['depth_multiplier'] = self.depth_multiplier
    config['depthwise_initializer'] = initializers.serialize(
        self.depthwise_initializer)
    config['depthwise_regularizer'] = regularizers.serialize(
        self.depthwise_regularizer)
    config['depthwise_constraint'] = constraints.serialize(
        self.depthwise_constraint)
    return config
@keras_export('keras.layers.UpSampling2D')
class UpSampling2D(Layer):
  """Upsampling layer for 2D inputs.

  Resizes the rows and columns of the data
  by `size[0]` and `size[1]` respectively, using the chosen interpolation.

  Arguments:
    size: Int, or tuple of 2 integers.
      The upsampling factors for rows and columns.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, height, width, channels)` while `channels_first`
      corresponds to inputs with shape
      `(batch, channels, height, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    interpolation: A string, one of `nearest` or `bilinear`.

  Input shape:
    4D tensor with shape:
    - If `data_format` is `"channels_last"`:
        `(batch, rows, cols, channels)`
    - If `data_format` is `"channels_first"`:
        `(batch, channels, rows, cols)`

  Output shape:
    4D tensor with shape:
    - If `data_format` is `"channels_last"`:
        `(batch, upsampled_rows, upsampled_cols, channels)`
    - If `data_format` is `"channels_first"`:
        `(batch, channels, upsampled_rows, upsampled_cols)`

  Raises:
    ValueError: If `interpolation` is neither `"nearest"` nor `"bilinear"`.
  """

  def __init__(self,
               size=(2, 2),
               data_format=None,
               interpolation='nearest',
               **kwargs):
    super(UpSampling2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.size = conv_utils.normalize_tuple(size, 2, 'size')
    if interpolation not in {'nearest', 'bilinear'}:
      raise ValueError('`interpolation` argument should be one of `"nearest"` '
                       'or `"bilinear"`.')
    self.interpolation = interpolation
    self.input_spec = InputSpec(ndim=4)

  def compute_output_shape(self, input_shape):
    """Scales the two spatial dimensions by `size`; unknown dims stay None."""
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if self.data_format == 'channels_first':
      height = self.size[0] * input_shape[
          2] if input_shape[2] is not None else None
      width = self.size[1] * input_shape[
          3] if input_shape[3] is not None else None
      return tensor_shape.TensorShape(
          [input_shape[0], input_shape[1], height, width])
    else:
      height = self.size[0] * input_shape[
          1] if input_shape[1] is not None else None
      width = self.size[1] * input_shape[
          2] if input_shape[2] is not None else None
      return tensor_shape.TensorShape(
          [input_shape[0], height, width, input_shape[3]])

  def call(self, inputs):
    return backend.resize_images(
        inputs, self.size[0], self.size[1], self.data_format,
        interpolation=self.interpolation)

  def get_config(self):
    """Returns the layer config, including the interpolation mode."""
    # `interpolation` must be serialized too: without it a 'bilinear' layer
    # is rebuilt with the default 'nearest' after a config round trip.
    config = {
        'size': self.size,
        'data_format': self.data_format,
        'interpolation': self.interpolation,
    }
    base_config = super(UpSampling2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
2000 2001 Input shape: 2002 5D tensor with shape: 2003 - If `data_format` is `"channels_last"`: 2004 `(batch, dim1, dim2, dim3, channels)` 2005 - If `data_format` is `"channels_first"`: 2006 `(batch, channels, dim1, dim2, dim3)` 2007 2008 Output shape: 2009 5D tensor with shape: 2010 - If `data_format` is `"channels_last"`: 2011 `(batch, upsampled_dim1, upsampled_dim2, upsampled_dim3, channels)` 2012 - If `data_format` is `"channels_first"`: 2013 `(batch, channels, upsampled_dim1, upsampled_dim2, upsampled_dim3)` 2014 """ 2015 2016 def __init__(self, size=(2, 2, 2), data_format=None, **kwargs): 2017 self.data_format = conv_utils.normalize_data_format(data_format) 2018 self.size = conv_utils.normalize_tuple(size, 3, 'size') 2019 self.input_spec = InputSpec(ndim=5) 2020 super(UpSampling3D, self).__init__(**kwargs) 2021 2022 def compute_output_shape(self, input_shape): 2023 input_shape = tensor_shape.TensorShape(input_shape).as_list() 2024 if self.data_format == 'channels_first': 2025 dim1 = self.size[0] * input_shape[ 2026 2] if input_shape[2] is not None else None 2027 dim2 = self.size[1] * input_shape[ 2028 3] if input_shape[3] is not None else None 2029 dim3 = self.size[2] * input_shape[ 2030 4] if input_shape[4] is not None else None 2031 return tensor_shape.TensorShape( 2032 [input_shape[0], input_shape[1], dim1, dim2, dim3]) 2033 else: 2034 dim1 = self.size[0] * input_shape[ 2035 1] if input_shape[1] is not None else None 2036 dim2 = self.size[1] * input_shape[ 2037 2] if input_shape[2] is not None else None 2038 dim3 = self.size[2] * input_shape[ 2039 3] if input_shape[3] is not None else None 2040 return tensor_shape.TensorShape( 2041 [input_shape[0], dim1, dim2, dim3, input_shape[4]]) 2042 2043 def call(self, inputs): 2044 return backend.resize_volumes( 2045 inputs, self.size[0], self.size[1], self.size[2], self.data_format) 2046 2047 def get_config(self): 2048 config = {'size': self.size, 'data_format': self.data_format} 2049 base_config = 
@keras_export('keras.layers.ZeroPadding1D')
class ZeroPadding1D(Layer):
  """Zero-padding layer for 1D input (e.g. temporal sequence).

  Arguments:
    padding: Int, or tuple of int (length 2).
      - If int:
        How many zeros to add at the beginning and end of
        the padding dimension (axis 1).
      - If tuple of int (length 2):
        How many zeros to add at the beginning and at the end of
        the padding dimension (`(left_pad, right_pad)`).

  Input shape:
    3D tensor with shape `(batch, axis_to_pad, features)`

  Output shape:
    3D tensor with shape `(batch, padded_axis, features)`
  """

  def __init__(self, padding=1, **kwargs):
    super(ZeroPadding1D, self).__init__(**kwargs)
    # A single int is expanded to `(padding, padding)`.
    self.padding = conv_utils.normalize_tuple(padding, 2, 'padding')
    self.input_spec = InputSpec(ndim=3)

  def compute_output_shape(self, input_shape):
    """Adds the left and right padding to axis 1; an unknown axis stays None."""
    # Normalize to a plain list of ints/None, as the other padding and
    # cropping layers do, so the `is not None` test below sees a real `None`
    # for unknown dimensions whatever shape type the caller passed in.
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if input_shape[1] is not None:
      length = input_shape[1] + self.padding[0] + self.padding[1]
    else:
      length = None
    return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])

  def call(self, inputs):
    return backend.temporal_padding(inputs, padding=self.padding)

  def get_config(self):
    config = {'padding': self.padding}
    base_config = super(ZeroPadding1D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
2105 - If tuple of 2 ints: 2106 interpreted as two different 2107 symmetric padding values for height and width: 2108 `(symmetric_height_pad, symmetric_width_pad)`. 2109 - If tuple of 2 tuples of 2 ints: 2110 interpreted as 2111 `((top_pad, bottom_pad), (left_pad, right_pad))` 2112 data_format: A string, 2113 one of `channels_last` (default) or `channels_first`. 2114 The ordering of the dimensions in the inputs. 2115 `channels_last` corresponds to inputs with shape 2116 `(batch, height, width, channels)` while `channels_first` 2117 corresponds to inputs with shape 2118 `(batch, channels, height, width)`. 2119 It defaults to the `image_data_format` value found in your 2120 Keras config file at `~/.keras/keras.json`. 2121 If you never set it, then it will be "channels_last". 2122 2123 Input shape: 2124 4D tensor with shape: 2125 - If `data_format` is `"channels_last"`: 2126 `(batch, rows, cols, channels)` 2127 - If `data_format` is `"channels_first"`: 2128 `(batch, channels, rows, cols)` 2129 2130 Output shape: 2131 4D tensor with shape: 2132 - If `data_format` is `"channels_last"`: 2133 `(batch, padded_rows, padded_cols, channels)` 2134 - If `data_format` is `"channels_first"`: 2135 `(batch, channels, padded_rows, padded_cols)` 2136 """ 2137 2138 def __init__(self, padding=(1, 1), data_format=None, **kwargs): 2139 super(ZeroPadding2D, self).__init__(**kwargs) 2140 self.data_format = conv_utils.normalize_data_format(data_format) 2141 if isinstance(padding, int): 2142 self.padding = ((padding, padding), (padding, padding)) 2143 elif hasattr(padding, '__len__'): 2144 if len(padding) != 2: 2145 raise ValueError('`padding` should have two elements. 
@keras_export('keras.layers.ZeroPadding3D')
class ZeroPadding3D(Layer):
  """Zero-padding layer for 3D data (spatial or spatio-temporal).

  Arguments:
    padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
      - If int: the same symmetric padding
        is applied to all three spatial dimensions.
      - If tuple of 3 ints:
        interpreted as three different
        symmetric padding values for dim1, dim2 and dim3:
        `(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.
      - If tuple of 3 tuples of 2 ints:
        interpreted as
        `((left_dim1_pad, right_dim1_pad), (left_dim2_pad,
          right_dim2_pad), (left_dim3_pad, right_dim3_pad))`
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
      while `channels_first` corresponds to inputs with shape
      `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Input shape:
    5D tensor with shape:
    - If `data_format` is `"channels_last"`:
        `(batch, first_axis_to_pad, second_axis_to_pad, third_axis_to_pad,
          depth)`
    - If `data_format` is `"channels_first"`:
        `(batch, depth, first_axis_to_pad, second_axis_to_pad,
          third_axis_to_pad)`

  Output shape:
    5D tensor with shape:
    - If `data_format` is `"channels_last"`:
        `(batch, first_padded_axis, second_padded_axis, third_padded_axis,
          depth)`
    - If `data_format` is `"channels_first"`:
        `(batch, depth, first_padded_axis, second_padded_axis,
          third_padded_axis)`

  Raises:
    ValueError: If `padding` is a sequence whose length is not 3, or is
      neither an int nor a sequence.
  """

  def __init__(self, padding=(1, 1, 1), data_format=None, **kwargs):
    super(ZeroPadding3D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    # Whatever form was given, `self.padding` ends up as three
    # `(before, after)` pairs, one per spatial dimension.
    if isinstance(padding, int):
      self.padding = ((padding, padding), (padding, padding), (padding,
                                                               padding))
    elif hasattr(padding, '__len__'):
      if len(padding) != 3:
        raise ValueError('`padding` should have 3 elements. '
                         'Found: ' + str(padding))
      dim1_padding = conv_utils.normalize_tuple(padding[0], 2,
                                                '1st entry of padding')
      dim2_padding = conv_utils.normalize_tuple(padding[1], 2,
                                                '2nd entry of padding')
      dim3_padding = conv_utils.normalize_tuple(padding[2], 2,
                                                '3rd entry of padding')
      self.padding = (dim1_padding, dim2_padding, dim3_padding)
    else:
      raise ValueError(
          '`padding` should be either an int, '
          'a tuple of 3 ints '
          '(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), '
          'or a tuple of 3 tuples of 2 ints '
          '((left_dim1_pad, right_dim1_pad),'
          ' (left_dim2_pad, right_dim2_pad),'
          ' (left_dim3_pad, right_dim3_pad)). '
          'Found: ' + str(padding))
    self.input_spec = InputSpec(ndim=5)

  def compute_output_shape(self, input_shape):
    """Adds each dimension's before+after padding; unknown dims stay None."""
    output_shape = tensor_shape.TensorShape(input_shape).as_list()
    # The three spatial axes follow the channel axis in `channels_first`
    # and the batch axis in `channels_last`.
    first_spatial_axis = 2 if self.data_format == 'channels_first' else 1
    for offset, (pad_before, pad_after) in enumerate(self.padding):
      axis = first_spatial_axis + offset
      if output_shape[axis] is not None:
        # Sum both sides: doubling only one of them is wrong whenever the
        # padding is asymmetric.
        output_shape[axis] += pad_before + pad_after
    return tensor_shape.TensorShape(output_shape)

  def call(self, inputs):
    return backend.spatial_3d_padding(
        inputs, padding=self.padding, data_format=self.data_format)

  def get_config(self):
    config = {'padding': self.padding, 'data_format': self.data_format}
    base_config = super(ZeroPadding3D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
@keras_export('keras.layers.Cropping2D')
class Cropping2D(Layer):
  """Cropping layer for 2D input (e.g. picture).

  It crops along spatial dimensions, i.e. height and width.

  Arguments:
    cropping: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
      - If int: the same symmetric cropping
        is applied to height and width.
      - If tuple of 2 ints:
        interpreted as two different
        symmetric cropping values for height and width:
        `(symmetric_height_crop, symmetric_width_crop)`.
      - If tuple of 2 tuples of 2 ints:
        interpreted as
        `((top_crop, bottom_crop), (left_crop, right_crop))`
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, height, width, channels)` while `channels_first`
      corresponds to inputs with shape
      `(batch, channels, height, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Input shape:
    4D tensor with shape:
    - If `data_format` is `"channels_last"`:
      `(batch, rows, cols, channels)`
    - If `data_format` is `"channels_first"`:
      `(batch, channels, rows, cols)`

  Output shape:
    4D tensor with shape:
    - If `data_format` is `"channels_last"`:
      `(batch, cropped_rows, cropped_cols, channels)`
    - If `data_format` is `"channels_first"`:
      `(batch, channels, cropped_rows, cropped_cols)`

  Raises:
    ValueError: If `cropping` is a sequence whose length is not 2, or is
      neither an int nor a sequence.

  Examples:

  ```python
  # Crop the input 2D images or feature maps
  model = Sequential()
  model.add(Cropping2D(cropping=((2, 2), (4, 4)),
                       input_shape=(28, 28, 3)))
  # now model.output_shape == (None, 24, 20, 3)
  model.add(Conv2D(64, (3, 3), padding='same'))
  model.add(Cropping2D(cropping=((2, 2), (2, 2))))
  # now model.output_shape == (None, 20, 16, 64)
  ```
  """

  def __init__(self, cropping=((0, 0), (0, 0)), data_format=None, **kwargs):
    super(Cropping2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    # Whatever form was given, `self.cropping` ends up as
    # `((top, bottom), (left, right))`.
    if isinstance(cropping, int):
      self.cropping = ((cropping, cropping), (cropping, cropping))
    elif hasattr(cropping, '__len__'):
      if len(cropping) != 2:
        raise ValueError('`cropping` should have two elements. '
                         'Found: ' + str(cropping))
      height_cropping = conv_utils.normalize_tuple(cropping[0], 2,
                                                   '1st entry of cropping')
      width_cropping = conv_utils.normalize_tuple(cropping[1], 2,
                                                  '2nd entry of cropping')
      self.cropping = (height_cropping, width_cropping)
    else:
      raise ValueError('`cropping` should be either an int, '
                       'a tuple of 2 ints '
                       '(symmetric_height_crop, symmetric_width_crop), '
                       'or a tuple of 2 tuples of 2 ints '
                       '((top_crop, bottom_crop), (left_crop, right_crop)). '
                       'Found: ' + str(cropping))
    self.input_spec = InputSpec(ndim=4)

  def compute_output_shape(self, input_shape):
    """Subtracts the cropping from both spatial dims; unknown dims stay None."""
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    (top, bottom), (left, right) = self.cropping
    if self.data_format == 'channels_first':
      row_axis, col_axis = 2, 3
    else:
      row_axis, col_axis = 1, 2

    output_shape = list(input_shape)
    # Compare against None explicitly (as Cropping1D/Cropping3D do) so that
    # only genuinely unknown dimensions are propagated as unknown.
    if input_shape[row_axis] is not None:
      output_shape[row_axis] = input_shape[row_axis] - top - bottom
    if input_shape[col_axis] is not None:
      output_shape[col_axis] = input_shape[col_axis] - left - right
    return tensor_shape.TensorShape(output_shape)

  def call(self, inputs):
    """Slices the cropped borders off the two spatial dimensions."""
    (top, bottom), (left, right) = self.cropping
    # A trailing crop of 0 must become an open-ended slice: `x[a:-0]` is
    # `x[a:0]`, which would be empty.
    row_slice = slice(top, -bottom if bottom else None)
    col_slice = slice(left, -right if right else None)
    if self.data_format == 'channels_first':
      return inputs[:, :, row_slice, col_slice]
    return inputs[:, row_slice, col_slice, :]

  def get_config(self):
    config = {'cropping': self.cropping, 'data_format': self.data_format}
    base_config = super(Cropping2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
2507 `channels_last` corresponds to inputs with shape 2508 `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 2509 while `channels_first` corresponds to inputs with shape 2510 `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 2511 It defaults to the `image_data_format` value found in your 2512 Keras config file at `~/.keras/keras.json`. 2513 If you never set it, then it will be "channels_last". 2514 2515 Input shape: 2516 5D tensor with shape: 2517 - If `data_format` is `"channels_last"`: 2518 `(batch, first_axis_to_crop, second_axis_to_crop, third_axis_to_crop, 2519 depth)` 2520 - If `data_format` is `"channels_first"`: 2521 `(batch, depth, first_axis_to_crop, second_axis_to_crop, 2522 third_axis_to_crop)` 2523 2524 Output shape: 2525 5D tensor with shape: 2526 - If `data_format` is `"channels_last"`: 2527 `(batch, first_cropped_axis, second_cropped_axis, third_cropped_axis, 2528 depth)` 2529 - If `data_format` is `"channels_first"`: 2530 `(batch, depth, first_cropped_axis, second_cropped_axis, 2531 third_cropped_axis)` 2532 """ 2533 2534 def __init__(self, 2535 cropping=((1, 1), (1, 1), (1, 1)), 2536 data_format=None, 2537 **kwargs): 2538 super(Cropping3D, self).__init__(**kwargs) 2539 self.data_format = conv_utils.normalize_data_format(data_format) 2540 if isinstance(cropping, int): 2541 self.cropping = ((cropping, cropping), (cropping, cropping), (cropping, 2542 cropping)) 2543 elif hasattr(cropping, '__len__'): 2544 if len(cropping) != 3: 2545 raise ValueError('`cropping` should have 3 elements. 
' 2546 'Found: ' + str(cropping)) 2547 dim1_cropping = conv_utils.normalize_tuple(cropping[0], 2, 2548 '1st entry of cropping') 2549 dim2_cropping = conv_utils.normalize_tuple(cropping[1], 2, 2550 '2nd entry of cropping') 2551 dim3_cropping = conv_utils.normalize_tuple(cropping[2], 2, 2552 '3rd entry of cropping') 2553 self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping) 2554 else: 2555 raise ValueError( 2556 '`cropping` should be either an int, ' 2557 'a tuple of 3 ints ' 2558 '(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop), ' 2559 'or a tuple of 3 tuples of 2 ints ' 2560 '((left_dim1_crop, right_dim1_crop),' 2561 ' (left_dim2_crop, right_dim2_crop),' 2562 ' (left_dim3_crop, right_dim2_crop)). ' 2563 'Found: ' + str(cropping)) 2564 self.input_spec = InputSpec(ndim=5) 2565 2566 def compute_output_shape(self, input_shape): 2567 input_shape = tensor_shape.TensorShape(input_shape).as_list() 2568 # pylint: disable=invalid-unary-operand-type 2569 if self.data_format == 'channels_first': 2570 if input_shape[2] is not None: 2571 dim1 = input_shape[2] - self.cropping[0][0] - self.cropping[0][1] 2572 else: 2573 dim1 = None 2574 if input_shape[3] is not None: 2575 dim2 = input_shape[3] - self.cropping[1][0] - self.cropping[1][1] 2576 else: 2577 dim2 = None 2578 if input_shape[4] is not None: 2579 dim3 = input_shape[4] - self.cropping[2][0] - self.cropping[2][1] 2580 else: 2581 dim3 = None 2582 return tensor_shape.TensorShape( 2583 [input_shape[0], input_shape[1], dim1, dim2, dim3]) 2584 elif self.data_format == 'channels_last': 2585 if input_shape[1] is not None: 2586 dim1 = input_shape[1] - self.cropping[0][0] - self.cropping[0][1] 2587 else: 2588 dim1 = None 2589 if input_shape[2] is not None: 2590 dim2 = input_shape[2] - self.cropping[1][0] - self.cropping[1][1] 2591 else: 2592 dim2 = None 2593 if input_shape[3] is not None: 2594 dim3 = input_shape[3] - self.cropping[2][0] - self.cropping[2][1] 2595 else: 2596 dim3 = None 2597 return 
tensor_shape.TensorShape( 2598 [input_shape[0], dim1, dim2, dim3, input_shape[4]]) 2599 # pylint: enable=invalid-unary-operand-type 2600 2601 def call(self, inputs): 2602 # pylint: disable=invalid-unary-operand-type 2603 if self.data_format == 'channels_first': 2604 if self.cropping[0][1] == self.cropping[1][1] == self.cropping[2][1] == 0: 2605 return inputs[:, :, self.cropping[0][0]:, self.cropping[1][0]:, 2606 self.cropping[2][0]:] 2607 elif self.cropping[0][1] == self.cropping[1][1] == 0: 2608 return inputs[:, :, self.cropping[0][0]:, self.cropping[1][0]:, 2609 self.cropping[2][0]:-self.cropping[2][1]] 2610 elif self.cropping[1][1] == self.cropping[2][1] == 0: 2611 return inputs[:, :, self.cropping[0][0]:-self.cropping[0][1], 2612 self.cropping[1][0]:, self.cropping[2][0]:] 2613 elif self.cropping[0][1] == self.cropping[2][1] == 0: 2614 return inputs[:, :, self.cropping[0][0]:, self.cropping[1][0]: 2615 -self.cropping[1][1], self.cropping[2][0]:] 2616 elif self.cropping[0][1] == 0: 2617 return inputs[:, :, self.cropping[0][0]:, self.cropping[1][ 2618 0]:-self.cropping[1][1], self.cropping[2][0]:-self.cropping[2][1]] 2619 elif self.cropping[1][1] == 0: 2620 return inputs[:, :, self.cropping[0][0]:-self.cropping[0][1], self. 2621 cropping[1][0]:, self.cropping[2][0]:-self.cropping[2][1]] 2622 elif self.cropping[2][1] == 0: 2623 return inputs[:, :, self.cropping[0][0]:-self.cropping[0][1], self. 
2624 cropping[1][0]:-self.cropping[1][1], self.cropping[2][0]:] 2625 return inputs[:, :, self.cropping[0][0]:-self.cropping[0][1], 2626 self.cropping[1][0]:-self.cropping[1][1], self.cropping[2][ 2627 0]:-self.cropping[2][1]] 2628 else: 2629 if self.cropping[0][1] == self.cropping[1][1] == self.cropping[2][1] == 0: 2630 return inputs[:, self.cropping[0][0]:, self.cropping[1][0]:, 2631 self.cropping[2][0]:, :] 2632 elif self.cropping[0][1] == self.cropping[1][1] == 0: 2633 return inputs[:, self.cropping[0][0]:, self.cropping[1][0]:, 2634 self.cropping[2][0]:-self.cropping[2][1], :] 2635 elif self.cropping[1][1] == self.cropping[2][1] == 0: 2636 return inputs[:, self.cropping[0][0]:-self.cropping[0][1], 2637 self.cropping[1][0]:, self.cropping[2][0]:, :] 2638 elif self.cropping[0][1] == self.cropping[2][1] == 0: 2639 return inputs[:, self.cropping[0][0]:, self.cropping[1][0]: 2640 -self.cropping[1][1], self.cropping[2][0]:, :] 2641 elif self.cropping[0][1] == 0: 2642 return inputs[:, self.cropping[0][0]:, self.cropping[1][ 2643 0]:-self.cropping[1][1], self.cropping[2][0]: 2644 -self.cropping[2][1], :] 2645 elif self.cropping[1][1] == 0: 2646 return inputs[:, self.cropping[0][ 2647 0]:-self.cropping[0][1], self.cropping[1][0]:, self.cropping[2][0]: 2648 -self.cropping[2][1], :] 2649 elif self.cropping[2][1] == 0: 2650 return inputs[:, self.cropping[0][0]:-self.cropping[0][1], 2651 self.cropping[1][0]:-self.cropping[1][1], self.cropping[ 2652 2][0]:, :] 2653 return inputs[:, self.cropping[0][0]:-self.cropping[0][1], self.cropping[ 2654 1][0]:-self.cropping[1][1], self.cropping[2][0]: # pylint: disable=invalid-unary-operand-type 2655 -self.cropping[2][1], :] # pylint: disable=invalid-unary-operand-type 2656 # pylint: enable=invalid-unary-operand-type 2657 2658 def get_config(self): 2659 config = {'cropping': self.cropping, 'data_format': self.data_format} 2660 base_config = super(Cropping3D, self).get_config() 2661 return dict(list(base_config.items()) + 
list(config.items())) 2662 2663 2664# Aliases 2665 2666Convolution1D = Conv1D 2667Convolution2D = Conv2D 2668Convolution3D = Conv3D 2669SeparableConvolution1D = SeparableConv1D 2670SeparableConvolution2D = SeparableConv2D 2671Convolution2DTranspose = Conv2DTranspose 2672Convolution3DTranspose = Conv3DTranspose 2673Deconvolution2D = Deconv2D = Conv2DTranspose 2674Deconvolution3D = Deconv3D = Conv3DTranspose 2675