1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Pooling layers.""" 16 17import functools 18 19from tensorflow.python.framework import tensor_shape 20from tensorflow.python.keras import backend 21from tensorflow.python.keras.engine.base_layer import Layer 22from tensorflow.python.keras.engine.input_spec import InputSpec 23from tensorflow.python.keras.utils import conv_utils 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops import nn 27from tensorflow.python.util.tf_export import keras_export 28 29 30class Pooling1D(Layer): 31 """Pooling layer for arbitrary pooling functions, for 1D inputs. 32 33 This class only exists for code reuse. It will never be an exposed API. 34 35 Args: 36 pool_function: The pooling function to apply, e.g. `tf.nn.max_pool2d`. 37 pool_size: An integer or tuple/list of a single integer, 38 representing the size of the pooling window. 39 strides: An integer or tuple/list of a single integer, specifying the 40 strides of the pooling operation. 41 padding: A string. The padding method, either 'valid' or 'same'. 42 Case-insensitive. 43 data_format: A string, 44 one of `channels_last` (default) or `channels_first`. 45 The ordering of the dimensions in the inputs. 46 `channels_last` corresponds to inputs with shape 47 `(batch, steps, features)` while `channels_first` 48 corresponds to inputs with shape 49 `(batch, features, steps)`. 50 name: A string, the name of the layer. 51 """ 52 53 def __init__(self, pool_function, pool_size, strides, 54 padding='valid', data_format='channels_last', 55 name=None, **kwargs): 56 super(Pooling1D, self).__init__(name=name, **kwargs) 57 if data_format is None: 58 data_format = backend.image_data_format() 59 if strides is None: 60 strides = pool_size 61 self.pool_function = pool_function 62 self.pool_size = conv_utils.normalize_tuple(pool_size, 1, 'pool_size') 63 self.strides = conv_utils.normalize_tuple(strides, 1, 'strides') 64 self.padding = conv_utils.normalize_padding(padding) 65 self.data_format = conv_utils.normalize_data_format(data_format) 66 self.input_spec = InputSpec(ndim=3) 67 68 def call(self, inputs): 69 pad_axis = 2 if self.data_format == 'channels_last' else 3 70 inputs = array_ops.expand_dims(inputs, pad_axis) 71 outputs = self.pool_function( 72 inputs, 73 self.pool_size + (1,), 74 strides=self.strides + (1,), 75 padding=self.padding, 76 data_format=self.data_format) 77 return array_ops.squeeze(outputs, pad_axis) 78 79 def compute_output_shape(self, input_shape): 80 input_shape = tensor_shape.TensorShape(input_shape).as_list() 81 if self.data_format == 'channels_first': 82 steps = input_shape[2] 83 features = input_shape[1] 84 else: 85 steps = input_shape[1] 86 features = input_shape[2] 87 length = conv_utils.conv_output_length(steps, 88 self.pool_size[0], 89 self.padding, 90 self.strides[0]) 91 if self.data_format == 'channels_first': 92 return tensor_shape.TensorShape([input_shape[0], features, length]) 93 else: 94 return tensor_shape.TensorShape([input_shape[0], length, features]) 95 96 def get_config(self): 97 config = { 98 'strides': self.strides, 99 'pool_size': self.pool_size, 100 'padding': self.padding, 101 'data_format': self.data_format, 102 } 103 base_config = super(Pooling1D, self).get_config() 104 return dict(list(base_config.items()) + list(config.items())) 105 106 107@keras_export('keras.layers.MaxPool1D', 'keras.layers.MaxPooling1D') 108class MaxPooling1D(Pooling1D): 109 """Max pooling operation for 1D temporal data. 110 111 Downsamples the input representation by taking the maximum value over a 112 spatial window of size `pool_size`. The window is shifted by `strides`. The 113 resulting output, when using the `"valid"` padding option, has a shape of: 114 `output_shape = (input_shape - pool_size + 1) / strides)` 115 116 The resulting output shape when using the `"same"` padding option is: 117 `output_shape = input_shape / strides` 118 119 For example, for `strides=1` and `padding="valid"`: 120 121 >>> x = tf.constant([1., 2., 3., 4., 5.]) 122 >>> x = tf.reshape(x, [1, 5, 1]) 123 >>> max_pool_1d = tf.keras.layers.MaxPooling1D(pool_size=2, 124 ... strides=1, padding='valid') 125 >>> max_pool_1d(x) 126 <tf.Tensor: shape=(1, 4, 1), dtype=float32, numpy= 127 array([[[2.], 128 [3.], 129 [4.], 130 [5.]]], dtype=float32)> 131 132 For example, for `strides=2` and `padding="valid"`: 133 134 >>> x = tf.constant([1., 2., 3., 4., 5.]) 135 >>> x = tf.reshape(x, [1, 5, 1]) 136 >>> max_pool_1d = tf.keras.layers.MaxPooling1D(pool_size=2, 137 ... strides=2, padding='valid') 138 >>> max_pool_1d(x) 139 <tf.Tensor: shape=(1, 2, 1), dtype=float32, numpy= 140 array([[[2.], 141 [4.]]], dtype=float32)> 142 143 For example, for `strides=1` and `padding="same"`: 144 145 >>> x = tf.constant([1., 2., 3., 4., 5.]) 146 >>> x = tf.reshape(x, [1, 5, 1]) 147 >>> max_pool_1d = tf.keras.layers.MaxPooling1D(pool_size=2, 148 ... strides=1, padding='same') 149 >>> max_pool_1d(x) 150 <tf.Tensor: shape=(1, 5, 1), dtype=float32, numpy= 151 array([[[2.], 152 [3.], 153 [4.], 154 [5.], 155 [5.]]], dtype=float32)> 156 157 Args: 158 pool_size: Integer, size of the max pooling window. 159 strides: Integer, or None. Specifies how much the pooling window moves 160 for each pooling step. 161 If None, it will default to `pool_size`. 162 padding: One of `"valid"` or `"same"` (case-insensitive). 163 `"valid"` means no padding. `"same"` results in padding evenly to 164 the left/right or up/down of the input such that output has the same 165 height/width dimension as the input. 166 data_format: A string, 167 one of `channels_last` (default) or `channels_first`. 168 The ordering of the dimensions in the inputs. 169 `channels_last` corresponds to inputs with shape 170 `(batch, steps, features)` while `channels_first` 171 corresponds to inputs with shape 172 `(batch, features, steps)`. 173 174 Input shape: 175 - If `data_format='channels_last'`: 176 3D tensor with shape `(batch_size, steps, features)`. 177 - If `data_format='channels_first'`: 178 3D tensor with shape `(batch_size, features, steps)`. 179 180 Output shape: 181 - If `data_format='channels_last'`: 182 3D tensor with shape `(batch_size, downsampled_steps, features)`. 183 - If `data_format='channels_first'`: 184 3D tensor with shape `(batch_size, features, downsampled_steps)`. 185 """ 186 187 def __init__(self, pool_size=2, strides=None, 188 padding='valid', data_format='channels_last', **kwargs): 189 190 super(MaxPooling1D, self).__init__( 191 functools.partial(backend.pool2d, pool_mode='max'), 192 pool_size=pool_size, 193 strides=strides, 194 padding=padding, 195 data_format=data_format, 196 **kwargs) 197 198 199@keras_export('keras.layers.AveragePooling1D', 'keras.layers.AvgPool1D') 200class AveragePooling1D(Pooling1D): 201 """Average pooling for temporal data. 202 203 Downsamples the input representation by taking the average value over the 204 window defined by `pool_size`. The window is shifted by `strides`. The 205 resulting output when using "valid" padding option has a shape of: 206 `output_shape = (input_shape - pool_size + 1) / strides)` 207 208 The resulting output shape when using the "same" padding option is: 209 `output_shape = input_shape / strides` 210 211 For example, for strides=1 and padding="valid": 212 213 >>> x = tf.constant([1., 2., 3., 4., 5.]) 214 >>> x = tf.reshape(x, [1, 5, 1]) 215 >>> x 216 <tf.Tensor: shape=(1, 5, 1), dtype=float32, numpy= 217 array([[[1.], 218 [2.], 219 [3.], 220 [4.], 221 [5.]], dtype=float32)> 222 >>> avg_pool_1d = tf.keras.layers.AveragePooling1D(pool_size=2, 223 ... strides=1, padding='valid') 224 >>> avg_pool_1d(x) 225 <tf.Tensor: shape=(1, 4, 1), dtype=float32, numpy= 226 array([[[1.5], 227 [2.5], 228 [3.5], 229 [4.5]]], dtype=float32)> 230 231 For example, for strides=2 and padding="valid": 232 233 >>> x = tf.constant([1., 2., 3., 4., 5.]) 234 >>> x = tf.reshape(x, [1, 5, 1]) 235 >>> x 236 <tf.Tensor: shape=(1, 5, 1), dtype=float32, numpy= 237 array([[[1.], 238 [2.], 239 [3.], 240 [4.], 241 [5.]], dtype=float32)> 242 >>> avg_pool_1d = tf.keras.layers.AveragePooling1D(pool_size=2, 243 ... strides=2, padding='valid') 244 >>> avg_pool_1d(x) 245 <tf.Tensor: shape=(1, 2, 1), dtype=float32, numpy= 246 array([[[1.5], 247 [3.5]]], dtype=float32)> 248 249 For example, for strides=1 and padding="same": 250 251 >>> x = tf.constant([1., 2., 3., 4., 5.]) 252 >>> x = tf.reshape(x, [1, 5, 1]) 253 >>> x 254 <tf.Tensor: shape=(1, 5, 1), dtype=float32, numpy= 255 array([[[1.], 256 [2.], 257 [3.], 258 [4.], 259 [5.]], dtype=float32)> 260 >>> avg_pool_1d = tf.keras.layers.AveragePooling1D(pool_size=2, 261 ... strides=1, padding='same') 262 >>> avg_pool_1d(x) 263 <tf.Tensor: shape=(1, 5, 1), dtype=float32, numpy= 264 array([[[1.5], 265 [2.5], 266 [3.5], 267 [4.5], 268 [5.]]], dtype=float32)> 269 270 Args: 271 pool_size: Integer, size of the average pooling windows. 272 strides: Integer, or None. Factor by which to downscale. 273 E.g. 2 will halve the input. 274 If None, it will default to `pool_size`. 275 padding: One of `"valid"` or `"same"` (case-insensitive). 276 `"valid"` means no padding. `"same"` results in padding evenly to 277 the left/right or up/down of the input such that output has the same 278 height/width dimension as the input. 279 data_format: A string, 280 one of `channels_last` (default) or `channels_first`. 281 The ordering of the dimensions in the inputs. 282 `channels_last` corresponds to inputs with shape 283 `(batch, steps, features)` while `channels_first` 284 corresponds to inputs with shape 285 `(batch, features, steps)`. 286 287 Input shape: 288 - If `data_format='channels_last'`: 289 3D tensor with shape `(batch_size, steps, features)`. 290 - If `data_format='channels_first'`: 291 3D tensor with shape `(batch_size, features, steps)`. 292 293 Output shape: 294 - If `data_format='channels_last'`: 295 3D tensor with shape `(batch_size, downsampled_steps, features)`. 296 - If `data_format='channels_first'`: 297 3D tensor with shape `(batch_size, features, downsampled_steps)`. 298 """ 299 300 def __init__(self, pool_size=2, strides=None, 301 padding='valid', data_format='channels_last', **kwargs): 302 super(AveragePooling1D, self).__init__( 303 functools.partial(backend.pool2d, pool_mode='avg'), 304 pool_size=pool_size, 305 strides=strides, 306 padding=padding, 307 data_format=data_format, 308 **kwargs) 309 310 311class Pooling2D(Layer): 312 """Pooling layer for arbitrary pooling functions, for 2D inputs (e.g. images). 313 314 This class only exists for code reuse. It will never be an exposed API. 315 316 Args: 317 pool_function: The pooling function to apply, e.g. `tf.nn.max_pool2d`. 318 pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width) 319 specifying the size of the pooling window. 320 Can be a single integer to specify the same value for 321 all spatial dimensions. 322 strides: An integer or tuple/list of 2 integers, 323 specifying the strides of the pooling operation. 324 Can be a single integer to specify the same value for 325 all spatial dimensions. 326 padding: A string. The padding method, either 'valid' or 'same'. 327 Case-insensitive. 328 data_format: A string, one of `channels_last` (default) or `channels_first`. 329 The ordering of the dimensions in the inputs. 330 `channels_last` corresponds to inputs with shape 331 `(batch, height, width, channels)` while `channels_first` corresponds to 332 inputs with shape `(batch, channels, height, width)`. 333 name: A string, the name of the layer. 334 """ 335 336 def __init__(self, pool_function, pool_size, strides, 337 padding='valid', data_format=None, 338 name=None, **kwargs): 339 super(Pooling2D, self).__init__(name=name, **kwargs) 340 if data_format is None: 341 data_format = backend.image_data_format() 342 if strides is None: 343 strides = pool_size 344 self.pool_function = pool_function 345 self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size') 346 self.strides = conv_utils.normalize_tuple(strides, 2, 'strides') 347 self.padding = conv_utils.normalize_padding(padding) 348 self.data_format = conv_utils.normalize_data_format(data_format) 349 self.input_spec = InputSpec(ndim=4) 350 351 def call(self, inputs): 352 if self.data_format == 'channels_last': 353 pool_shape = (1,) + self.pool_size + (1,) 354 strides = (1,) + self.strides + (1,) 355 else: 356 pool_shape = (1, 1) + self.pool_size 357 strides = (1, 1) + self.strides 358 outputs = self.pool_function( 359 inputs, 360 ksize=pool_shape, 361 strides=strides, 362 padding=self.padding.upper(), 363 data_format=conv_utils.convert_data_format(self.data_format, 4)) 364 return outputs 365 366 def compute_output_shape(self, input_shape): 367 input_shape = tensor_shape.TensorShape(input_shape).as_list() 368 if self.data_format == 'channels_first': 369 rows = input_shape[2] 370 cols = input_shape[3] 371 else: 372 rows = input_shape[1] 373 cols = input_shape[2] 374 rows = conv_utils.conv_output_length(rows, self.pool_size[0], self.padding, 375 self.strides[0]) 376 cols = conv_utils.conv_output_length(cols, self.pool_size[1], self.padding, 377 self.strides[1]) 378 if self.data_format == 'channels_first': 379 return tensor_shape.TensorShape( 380 [input_shape[0], input_shape[1], rows, cols]) 381 else: 382 return tensor_shape.TensorShape( 383 [input_shape[0], rows, cols, input_shape[3]]) 384 385 def get_config(self): 386 config = { 387 'pool_size': self.pool_size, 388 'padding': self.padding, 389 'strides': self.strides, 390 'data_format': self.data_format 391 } 392 base_config = super(Pooling2D, self).get_config() 393 return dict(list(base_config.items()) + list(config.items())) 394 395 396@keras_export('keras.layers.MaxPool2D', 'keras.layers.MaxPooling2D') 397class MaxPooling2D(Pooling2D): 398 """Max pooling operation for 2D spatial data. 399 400 Downsamples the input along its spatial dimensions (height and width) 401 by taking the maximum value over an input window 402 (of size defined by `pool_size`) for each channel of the input. 403 The window is shifted by `strides` along each dimension. 404 405 The resulting output, 406 when using the `"valid"` padding option, has a spatial shape 407 (number of rows or columns) of: 408 `output_shape = math.floor((input_shape - pool_size) / strides) + 1` 409 (when `input_shape >= pool_size`) 410 411 The resulting output shape when using the `"same"` padding option is: 412 `output_shape = math.floor((input_shape - 1) / strides) + 1` 413 414 For example, for `strides=(1, 1)` and `padding="valid"`: 415 416 >>> x = tf.constant([[1., 2., 3.], 417 ... [4., 5., 6.], 418 ... [7., 8., 9.]]) 419 >>> x = tf.reshape(x, [1, 3, 3, 1]) 420 >>> max_pool_2d = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), 421 ... strides=(1, 1), padding='valid') 422 >>> max_pool_2d(x) 423 <tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy= 424 array([[[[5.], 425 [6.]], 426 [[8.], 427 [9.]]]], dtype=float32)> 428 429 For example, for `strides=(2, 2)` and `padding="valid"`: 430 431 >>> x = tf.constant([[1., 2., 3., 4.], 432 ... [5., 6., 7., 8.], 433 ... [9., 10., 11., 12.]]) 434 >>> x = tf.reshape(x, [1, 3, 4, 1]) 435 >>> max_pool_2d = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), 436 ... strides=(2, 2), padding='valid') 437 >>> max_pool_2d(x) 438 <tf.Tensor: shape=(1, 1, 2, 1), dtype=float32, numpy= 439 array([[[[6.], 440 [8.]]]], dtype=float32)> 441 442 Usage Example: 443 444 >>> input_image = tf.constant([[[[1.], [1.], [2.], [4.]], 445 ... [[2.], [2.], [3.], [2.]], 446 ... [[4.], [1.], [1.], [1.]], 447 ... [[2.], [2.], [1.], [4.]]]]) 448 >>> output = tf.constant([[[[1], [0]], 449 ... [[0], [1]]]]) 450 >>> model = tf.keras.models.Sequential() 451 >>> model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), 452 ... input_shape=(4, 4, 1))) 453 >>> model.compile('adam', 'mean_squared_error') 454 >>> model.predict(input_image, steps=1) 455 array([[[[2.], 456 [4.]], 457 [[4.], 458 [4.]]]], dtype=float32) 459 460 For example, for stride=(1, 1) and padding="same": 461 462 >>> x = tf.constant([[1., 2., 3.], 463 ... [4., 5., 6.], 464 ... [7., 8., 9.]]) 465 >>> x = tf.reshape(x, [1, 3, 3, 1]) 466 >>> max_pool_2d = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), 467 ... strides=(1, 1), padding='same') 468 >>> max_pool_2d(x) 469 <tf.Tensor: shape=(1, 3, 3, 1), dtype=float32, numpy= 470 array([[[[5.], 471 [6.], 472 [6.]], 473 [[8.], 474 [9.], 475 [9.]], 476 [[8.], 477 [9.], 478 [9.]]]], dtype=float32)> 479 480 Args: 481 pool_size: integer or tuple of 2 integers, 482 window size over which to take the maximum. 483 `(2, 2)` will take the max value over a 2x2 pooling window. 484 If only one integer is specified, the same window length 485 will be used for both dimensions. 486 strides: Integer, tuple of 2 integers, or None. 487 Strides values. Specifies how far the pooling window moves 488 for each pooling step. If None, it will default to `pool_size`. 489 padding: One of `"valid"` or `"same"` (case-insensitive). 490 `"valid"` means no padding. `"same"` results in padding evenly to 491 the left/right or up/down of the input such that output has the same 492 height/width dimension as the input. 493 data_format: A string, 494 one of `channels_last` (default) or `channels_first`. 495 The ordering of the dimensions in the inputs. 496 `channels_last` corresponds to inputs with shape 497 `(batch, height, width, channels)` while `channels_first` 498 corresponds to inputs with shape 499 `(batch, channels, height, width)`. 500 It defaults to the `image_data_format` value found in your 501 Keras config file at `~/.keras/keras.json`. 502 If you never set it, then it will be "channels_last". 503 504 Input shape: 505 - If `data_format='channels_last'`: 506 4D tensor with shape `(batch_size, rows, cols, channels)`. 507 - If `data_format='channels_first'`: 508 4D tensor with shape `(batch_size, channels, rows, cols)`. 509 510 Output shape: 511 - If `data_format='channels_last'`: 512 4D tensor with shape `(batch_size, pooled_rows, pooled_cols, channels)`. 513 - If `data_format='channels_first'`: 514 4D tensor with shape `(batch_size, channels, pooled_rows, pooled_cols)`. 515 516 Returns: 517 A tensor of rank 4 representing the maximum pooled values. See above for 518 output shape. 519 """ 520 521 def __init__(self, 522 pool_size=(2, 2), 523 strides=None, 524 padding='valid', 525 data_format=None, 526 **kwargs): 527 super(MaxPooling2D, self).__init__( 528 nn.max_pool, 529 pool_size=pool_size, strides=strides, 530 padding=padding, data_format=data_format, **kwargs) 531 532 533@keras_export('keras.layers.AveragePooling2D', 'keras.layers.AvgPool2D') 534class AveragePooling2D(Pooling2D): 535 """Average pooling operation for spatial data. 536 537 Downsamples the input along its spatial dimensions (height and width) 538 by taking the average value over an input window 539 (of size defined by `pool_size`) for each channel of the input. 540 The window is shifted by `strides` along each dimension. 541 542 The resulting output when using `"valid"` padding option has a shape 543 (number of rows or columns) of: 544 `output_shape = math.floor((input_shape - pool_size) / strides) + 1` 545 (when `input_shape >= pool_size`) 546 547 The resulting output shape when using the `"same"` padding option is: 548 `output_shape = math.floor((input_shape - 1) / strides) + 1` 549 550 For example, for `strides=(1, 1)` and `padding="valid"`: 551 552 >>> x = tf.constant([[1., 2., 3.], 553 ... [4., 5., 6.], 554 ... [7., 8., 9.]]) 555 >>> x = tf.reshape(x, [1, 3, 3, 1]) 556 >>> avg_pool_2d = tf.keras.layers.AveragePooling2D(pool_size=(2, 2), 557 ... strides=(1, 1), padding='valid') 558 >>> avg_pool_2d(x) 559 <tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy= 560 array([[[[3.], 561 [4.]], 562 [[6.], 563 [7.]]]], dtype=float32)> 564 565 For example, for `stride=(2, 2)` and `padding="valid"`: 566 567 >>> x = tf.constant([[1., 2., 3., 4.], 568 ... [5., 6., 7., 8.], 569 ... [9., 10., 11., 12.]]) 570 >>> x = tf.reshape(x, [1, 3, 4, 1]) 571 >>> avg_pool_2d = tf.keras.layers.AveragePooling2D(pool_size=(2, 2), 572 ... strides=(2, 2), padding='valid') 573 >>> avg_pool_2d(x) 574 <tf.Tensor: shape=(1, 1, 2, 1), dtype=float32, numpy= 575 array([[[[3.5], 576 [5.5]]]], dtype=float32)> 577 578 For example, for `strides=(1, 1)` and `padding="same"`: 579 580 >>> x = tf.constant([[1., 2., 3.], 581 ... [4., 5., 6.], 582 ... [7., 8., 9.]]) 583 >>> x = tf.reshape(x, [1, 3, 3, 1]) 584 >>> avg_pool_2d = tf.keras.layers.AveragePooling2D(pool_size=(2, 2), 585 ... strides=(1, 1), padding='same') 586 >>> avg_pool_2d(x) 587 <tf.Tensor: shape=(1, 3, 3, 1), dtype=float32, numpy= 588 array([[[[3.], 589 [4.], 590 [4.5]], 591 [[6.], 592 [7.], 593 [7.5]], 594 [[7.5], 595 [8.5], 596 [9.]]]], dtype=float32)> 597 598 Args: 599 pool_size: integer or tuple of 2 integers, 600 factors by which to downscale (vertical, horizontal). 601 `(2, 2)` will halve the input in both spatial dimension. 602 If only one integer is specified, the same window length 603 will be used for both dimensions. 604 strides: Integer, tuple of 2 integers, or None. 605 Strides values. 606 If None, it will default to `pool_size`. 607 padding: One of `"valid"` or `"same"` (case-insensitive). 608 `"valid"` means no padding. `"same"` results in padding evenly to 609 the left/right or up/down of the input such that output has the same 610 height/width dimension as the input. 611 data_format: A string, 612 one of `channels_last` (default) or `channels_first`. 613 The ordering of the dimensions in the inputs. 614 `channels_last` corresponds to inputs with shape 615 `(batch, height, width, channels)` while `channels_first` 616 corresponds to inputs with shape 617 `(batch, channels, height, width)`. 618 It defaults to the `image_data_format` value found in your 619 Keras config file at `~/.keras/keras.json`. 620 If you never set it, then it will be "channels_last". 621 622 Input shape: 623 - If `data_format='channels_last'`: 624 4D tensor with shape `(batch_size, rows, cols, channels)`. 625 - If `data_format='channels_first'`: 626 4D tensor with shape `(batch_size, channels, rows, cols)`. 627 628 Output shape: 629 - If `data_format='channels_last'`: 630 4D tensor with shape `(batch_size, pooled_rows, pooled_cols, channels)`. 631 - If `data_format='channels_first'`: 632 4D tensor with shape `(batch_size, channels, pooled_rows, pooled_cols)`. 633 """ 634 635 def __init__(self, 636 pool_size=(2, 2), 637 strides=None, 638 padding='valid', 639 data_format=None, 640 **kwargs): 641 super(AveragePooling2D, self).__init__( 642 nn.avg_pool, 643 pool_size=pool_size, strides=strides, 644 padding=padding, data_format=data_format, **kwargs) 645 646 647class Pooling3D(Layer): 648 """Pooling layer for arbitrary pooling functions, for 3D inputs. 649 650 This class only exists for code reuse. It will never be an exposed API. 651 652 Args: 653 pool_function: The pooling function to apply, e.g. `tf.nn.max_pool2d`. 654 pool_size: An integer or tuple/list of 3 integers: 655 (pool_depth, pool_height, pool_width) 656 specifying the size of the pooling window. 657 Can be a single integer to specify the same value for 658 all spatial dimensions. 659 strides: An integer or tuple/list of 3 integers, 660 specifying the strides of the pooling operation. 661 Can be a single integer to specify the same value for 662 all spatial dimensions. 663 padding: A string. The padding method, either 'valid' or 'same'. 664 Case-insensitive. 665 data_format: A string, one of `channels_last` (default) or `channels_first`. 666 The ordering of the dimensions in the inputs. 667 `channels_last` corresponds to inputs with shape 668 `(batch, depth, height, width, channels)` 669 while `channels_first` corresponds to 670 inputs with shape `(batch, channels, depth, height, width)`. 671 name: A string, the name of the layer. 672 """ 673 674 def __init__(self, pool_function, pool_size, strides, 675 padding='valid', data_format='channels_last', 676 name=None, **kwargs): 677 super(Pooling3D, self).__init__(name=name, **kwargs) 678 if data_format is None: 679 data_format = backend.image_data_format() 680 if strides is None: 681 strides = pool_size 682 self.pool_function = pool_function 683 self.pool_size = conv_utils.normalize_tuple(pool_size, 3, 'pool_size') 684 self.strides = conv_utils.normalize_tuple(strides, 3, 'strides') 685 self.padding = conv_utils.normalize_padding(padding) 686 self.data_format = conv_utils.normalize_data_format(data_format) 687 self.input_spec = InputSpec(ndim=5) 688 689 def call(self, inputs): 690 pool_shape = (1,) + self.pool_size + (1,) 691 strides = (1,) + self.strides + (1,) 692 693 if self.data_format == 'channels_first': 694 # TF does not support `channels_first` with 3D pooling operations, 695 # so we must handle this case manually. 696 # TODO(fchollet): remove this when TF pooling is feature-complete. 697 inputs = array_ops.transpose(inputs, (0, 2, 3, 4, 1)) 698 699 outputs = self.pool_function( 700 inputs, 701 ksize=pool_shape, 702 strides=strides, 703 padding=self.padding.upper()) 704 705 if self.data_format == 'channels_first': 706 outputs = array_ops.transpose(outputs, (0, 4, 1, 2, 3)) 707 return outputs 708 709 def compute_output_shape(self, input_shape): 710 input_shape = tensor_shape.TensorShape(input_shape).as_list() 711 if self.data_format == 'channels_first': 712 len_dim1 = input_shape[2] 713 len_dim2 = input_shape[3] 714 len_dim3 = input_shape[4] 715 else: 716 len_dim1 = input_shape[1] 717 len_dim2 = input_shape[2] 718 len_dim3 = input_shape[3] 719 len_dim1 = conv_utils.conv_output_length(len_dim1, self.pool_size[0], 720 self.padding, self.strides[0]) 721 len_dim2 = conv_utils.conv_output_length(len_dim2, self.pool_size[1], 722 self.padding, self.strides[1]) 723 len_dim3 = conv_utils.conv_output_length(len_dim3, self.pool_size[2], 724 self.padding, self.strides[2]) 725 if self.data_format == 'channels_first': 726 return tensor_shape.TensorShape( 727 [input_shape[0], input_shape[1], len_dim1, len_dim2, len_dim3]) 728 else: 729 return tensor_shape.TensorShape( 730 [input_shape[0], len_dim1, len_dim2, len_dim3, input_shape[4]]) 731 732 def get_config(self): 733 config = { 734 'pool_size': self.pool_size, 735 'padding': self.padding, 736 'strides': self.strides, 737 'data_format': self.data_format 738 } 739 base_config = super(Pooling3D, self).get_config() 740 return dict(list(base_config.items()) + list(config.items())) 741 742 743@keras_export('keras.layers.MaxPool3D', 'keras.layers.MaxPooling3D') 744class MaxPooling3D(Pooling3D): 745 """Max pooling operation for 3D data (spatial or spatio-temporal). 746 747 Downsamples the input along its spatial dimensions (depth, height, and width) 748 by taking the maximum value over an input window 749 (of size defined by `pool_size`) for each channel of the input. 750 The window is shifted by `strides` along each dimension. 751 752 Args: 753 pool_size: Tuple of 3 integers, 754 factors by which to downscale (dim1, dim2, dim3). 755 `(2, 2, 2)` will halve the size of the 3D input in each dimension. 756 strides: tuple of 3 integers, or None. Strides values. 757 padding: One of `"valid"` or `"same"` (case-insensitive). 758 `"valid"` means no padding. `"same"` results in padding evenly to 759 the left/right or up/down of the input such that output has the same 760 height/width dimension as the input. 761 data_format: A string, 762 one of `channels_last` (default) or `channels_first`. 763 The ordering of the dimensions in the inputs. 764 `channels_last` corresponds to inputs with shape 765 `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 766 while `channels_first` corresponds to inputs with shape 767 `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 768 It defaults to the `image_data_format` value found in your 769 Keras config file at `~/.keras/keras.json`. 770 If you never set it, then it will be "channels_last". 771 772 Input shape: 773 - If `data_format='channels_last'`: 774 5D tensor with shape: 775 `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 776 - If `data_format='channels_first'`: 777 5D tensor with shape: 778 `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` 779 780 Output shape: 781 - If `data_format='channels_last'`: 782 5D tensor with shape: 783 `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)` 784 - If `data_format='channels_first'`: 785 5D tensor with shape: 786 `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)` 787 788 Example: 789 790 ```python 791 depth = 30 792 height = 30 793 width = 30 794 input_channels = 3 795 796 inputs = tf.keras.Input(shape=(depth, height, width, input_channels)) 797 layer = tf.keras.layers.MaxPooling3D(pool_size=3) 798 outputs = layer(inputs) # Shape: (batch_size, 10, 10, 10, 3) 799 ``` 800 """ 801 802 def __init__(self, 803 pool_size=(2, 2, 2), 804 strides=None, 805 padding='valid', 806 data_format=None, 807 **kwargs): 808 super(MaxPooling3D, self).__init__( 809 nn.max_pool3d, 810 pool_size=pool_size, strides=strides, 811 padding=padding, data_format=data_format, **kwargs) 812 813 814@keras_export('keras.layers.AveragePooling3D', 'keras.layers.AvgPool3D') 815class AveragePooling3D(Pooling3D): 816 """Average pooling operation for 3D data (spatial or spatio-temporal). 817 818 Downsamples the input along its spatial dimensions (depth, height, and width) 819 by taking the average value over an input window 820 (of size defined by `pool_size`) for each channel of the input. 821 The window is shifted by `strides` along each dimension. 822 823 Args: 824 pool_size: tuple of 3 integers, 825 factors by which to downscale (dim1, dim2, dim3). 826 `(2, 2, 2)` will halve the size of the 3D input in each dimension. 827 strides: tuple of 3 integers, or None. Strides values. 828 padding: One of `"valid"` or `"same"` (case-insensitive). 829 `"valid"` means no padding. `"same"` results in padding evenly to 830 the left/right or up/down of the input such that output has the same 831 height/width dimension as the input. 832 data_format: A string, 833 one of `channels_last` (default) or `channels_first`. 834 The ordering of the dimensions in the inputs. 835 `channels_last` corresponds to inputs with shape 836 `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 837 while `channels_first` corresponds to inputs with shape 838 `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 839 It defaults to the `image_data_format` value found in your 840 Keras config file at `~/.keras/keras.json`. 841 If you never set it, then it will be "channels_last". 842 843 Input shape: 844 - If `data_format='channels_last'`: 845 5D tensor with shape: 846 `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 847 - If `data_format='channels_first'`: 848 5D tensor with shape: 849 `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` 850 851 Output shape: 852 - If `data_format='channels_last'`: 853 5D tensor with shape: 854 `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)` 855 - If `data_format='channels_first'`: 856 5D tensor with shape: 857 `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)` 858 859 Example: 860 861 ```python 862 depth = 30 863 height = 30 864 width = 30 865 input_channels = 3 866 867 inputs = tf.keras.Input(shape=(depth, height, width, input_channels)) 868 layer = tf.keras.layers.AveragePooling3D(pool_size=3) 869 outputs = layer(inputs) # Shape: (batch_size, 10, 10, 10, 3) 870 ``` 871 """ 872 873 def __init__(self, 874 pool_size=(2, 2, 2), 875 strides=None, 876 padding='valid', 877 data_format=None, 878 **kwargs): 879 super(AveragePooling3D, self).__init__( 880 nn.avg_pool3d, 881 pool_size=pool_size, strides=strides, 882 padding=padding, data_format=data_format, **kwargs) 883 884 885class GlobalPooling1D(Layer): 886 """Abstract class for different global pooling 1D layers.""" 887 888 def __init__(self, data_format='channels_last', keepdims=False, **kwargs): 889 super(GlobalPooling1D, self).__init__(**kwargs) 890 self.input_spec = InputSpec(ndim=3) 891 self.data_format = conv_utils.normalize_data_format(data_format) 892 self.keepdims = keepdims 893 894 def compute_output_shape(self, input_shape): 895 input_shape = tensor_shape.TensorShape(input_shape).as_list() 896 if self.data_format == 'channels_first': 897 if self.keepdims: 898 return tensor_shape.TensorShape([input_shape[0], input_shape[1], 1]) 899 else: 900 return tensor_shape.TensorShape([input_shape[0], input_shape[1]]) 901 else: 902 if self.keepdims: 903 return tensor_shape.TensorShape([input_shape[0], 1, input_shape[2]]) 904 else: 905 return tensor_shape.TensorShape([input_shape[0], input_shape[2]]) 906 907 def call(self, inputs): 908 raise NotImplementedError 909 910 def get_config(self): 911 config = {'data_format': self.data_format, 'keepdims': self.keepdims} 912 base_config = super(GlobalPooling1D, self).get_config() 913 return dict(list(base_config.items()) + list(config.items())) 914 915 916@keras_export('keras.layers.GlobalAveragePooling1D', 917 'keras.layers.GlobalAvgPool1D') 918class GlobalAveragePooling1D(GlobalPooling1D): 919 """Global average pooling operation for temporal data. 920 921 Examples: 922 923 >>> input_shape = (2, 3, 4) 924 >>> x = tf.random.normal(input_shape) 925 >>> y = tf.keras.layers.GlobalAveragePooling1D()(x) 926 >>> print(y.shape) 927 (2, 4) 928 929 Args: 930 data_format: A string, 931 one of `channels_last` (default) or `channels_first`. 932 The ordering of the dimensions in the inputs. 933 `channels_last` corresponds to inputs with shape 934 `(batch, steps, features)` while `channels_first` 935 corresponds to inputs with shape 936 `(batch, features, steps)`. 937 keepdims: A boolean, whether to keep the temporal dimension or not. 938 If `keepdims` is `False` (default), the rank of the tensor is reduced 939 for spatial dimensions. 940 If `keepdims` is `True`, the temporal dimension are retained with 941 length 1. 942 The behavior is the same as for `tf.reduce_mean` or `np.mean`. 943 944 Call arguments: 945 inputs: A 3D tensor. 946 mask: Binary tensor of shape `(batch_size, steps)` indicating whether 947 a given step should be masked (excluded from the average). 948 949 Input shape: 950 - If `data_format='channels_last'`: 951 3D tensor with shape: 952 `(batch_size, steps, features)` 953 - If `data_format='channels_first'`: 954 3D tensor with shape: 955 `(batch_size, features, steps)` 956 957 Output shape: 958 - If `keepdims`=False: 959 2D tensor with shape `(batch_size, features)`. 960 - If `keepdims`=True: 961 - If `data_format='channels_last'`: 962 3D tensor with shape `(batch_size, 1, features)` 963 - If `data_format='channels_first'`: 964 3D tensor with shape `(batch_size, features, 1)` 965 """ 966 967 def __init__(self, data_format='channels_last', **kwargs): 968 super(GlobalAveragePooling1D, self).__init__(data_format=data_format, 969 **kwargs) 970 self.supports_masking = True 971 972 def call(self, inputs, mask=None): 973 steps_axis = 1 if self.data_format == 'channels_last' else 2 974 if mask is not None: 975 mask = math_ops.cast(mask, inputs[0].dtype) 976 mask = array_ops.expand_dims( 977 mask, 2 if self.data_format == 'channels_last' else 1) 978 inputs *= mask 979 return backend.sum( 980 inputs, axis=steps_axis, 981 keepdims=self.keepdims) / math_ops.reduce_sum( 982 mask, axis=steps_axis, keepdims=self.keepdims) 983 else: 984 return backend.mean(inputs, axis=steps_axis, keepdims=self.keepdims) 985 986 def compute_mask(self, inputs, mask=None): 987 return None 988 989 990@keras_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D') 991class GlobalMaxPooling1D(GlobalPooling1D): 992 """Global max pooling operation for 1D temporal data. 993 994 Downsamples the input representation by taking the maximum value over 995 the time dimension. 996 997 For example: 998 999 >>> x = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) 1000 >>> x = tf.reshape(x, [3, 3, 1]) 1001 >>> x 1002 <tf.Tensor: shape=(3, 3, 1), dtype=float32, numpy= 1003 array([[[1.], [2.], [3.]], 1004 [[4.], [5.], [6.]], 1005 [[7.], [8.], [9.]]], dtype=float32)> 1006 >>> max_pool_1d = tf.keras.layers.GlobalMaxPooling1D() 1007 >>> max_pool_1d(x) 1008 <tf.Tensor: shape=(3, 1), dtype=float32, numpy= 1009 array([[3.], 1010 [6.], 1011 [9.], dtype=float32)> 1012 1013 Args: 1014 data_format: A string, 1015 one of `channels_last` (default) or `channels_first`. 1016 The ordering of the dimensions in the inputs. 1017 `channels_last` corresponds to inputs with shape 1018 `(batch, steps, features)` while `channels_first` 1019 corresponds to inputs with shape 1020 `(batch, features, steps)`. 1021 keepdims: A boolean, whether to keep the temporal dimension or not. 1022 If `keepdims` is `False` (default), the rank of the tensor is reduced 1023 for spatial dimensions. 1024 If `keepdims` is `True`, the temporal dimension are retained with 1025 length 1. 1026 The behavior is the same as for `tf.reduce_max` or `np.max`. 1027 1028 Input shape: 1029 - If `data_format='channels_last'`: 1030 3D tensor with shape: 1031 `(batch_size, steps, features)` 1032 - If `data_format='channels_first'`: 1033 3D tensor with shape: 1034 `(batch_size, features, steps)` 1035 1036 Output shape: 1037 - If `keepdims`=False: 1038 2D tensor with shape `(batch_size, features)`. 1039 - If `keepdims`=True: 1040 - If `data_format='channels_last'`: 1041 3D tensor with shape `(batch_size, 1, features)` 1042 - If `data_format='channels_first'`: 1043 3D tensor with shape `(batch_size, features, 1)` 1044 """ 1045 1046 def call(self, inputs): 1047 steps_axis = 1 if self.data_format == 'channels_last' else 2 1048 return backend.max(inputs, axis=steps_axis, keepdims=self.keepdims) 1049 1050 1051class GlobalPooling2D(Layer): 1052 """Abstract class for different global pooling 2D layers. 1053 """ 1054 1055 def __init__(self, data_format=None, keepdims=False, **kwargs): 1056 super(GlobalPooling2D, self).__init__(**kwargs) 1057 self.data_format = conv_utils.normalize_data_format(data_format) 1058 self.input_spec = InputSpec(ndim=4) 1059 self.keepdims = keepdims 1060 1061 def compute_output_shape(self, input_shape): 1062 input_shape = tensor_shape.TensorShape(input_shape).as_list() 1063 if self.data_format == 'channels_last': 1064 if self.keepdims: 1065 return tensor_shape.TensorShape([input_shape[0], 1, 1, input_shape[3]]) 1066 else: 1067 return tensor_shape.TensorShape([input_shape[0], input_shape[3]]) 1068 else: 1069 if self.keepdims: 1070 return tensor_shape.TensorShape([input_shape[0], input_shape[1], 1, 1]) 1071 else: 1072 return tensor_shape.TensorShape([input_shape[0], input_shape[1]]) 1073 1074 def call(self, inputs): 1075 raise NotImplementedError 1076 1077 def get_config(self): 1078 config = {'data_format': self.data_format, 'keepdims': self.keepdims} 1079 base_config = super(GlobalPooling2D, self).get_config() 1080 return dict(list(base_config.items()) + list(config.items())) 1081 1082 1083@keras_export('keras.layers.GlobalAveragePooling2D', 1084 'keras.layers.GlobalAvgPool2D') 1085class GlobalAveragePooling2D(GlobalPooling2D): 1086 """Global average pooling operation for spatial data. 1087 1088 Examples: 1089 1090 >>> input_shape = (2, 4, 5, 3) 1091 >>> x = tf.random.normal(input_shape) 1092 >>> y = tf.keras.layers.GlobalAveragePooling2D()(x) 1093 >>> print(y.shape) 1094 (2, 3) 1095 1096 Args: 1097 data_format: A string, 1098 one of `channels_last` (default) or `channels_first`. 1099 The ordering of the dimensions in the inputs. 1100 `channels_last` corresponds to inputs with shape 1101 `(batch, height, width, channels)` while `channels_first` 1102 corresponds to inputs with shape 1103 `(batch, channels, height, width)`. 1104 It defaults to the `image_data_format` value found in your 1105 Keras config file at `~/.keras/keras.json`. 1106 If you never set it, then it will be "channels_last". 1107 keepdims: A boolean, whether to keep the spatial dimensions or not. 1108 If `keepdims` is `False` (default), the rank of the tensor is reduced 1109 for spatial dimensions. 1110 If `keepdims` is `True`, the spatial dimensions are retained with 1111 length 1. 1112 The behavior is the same as for `tf.reduce_mean` or `np.mean`. 1113 1114 Input shape: 1115 - If `data_format='channels_last'`: 1116 4D tensor with shape `(batch_size, rows, cols, channels)`. 1117 - If `data_format='channels_first'`: 1118 4D tensor with shape `(batch_size, channels, rows, cols)`. 1119 1120 Output shape: 1121 - If `keepdims`=False: 1122 2D tensor with shape `(batch_size, channels)`. 1123 - If `keepdims`=True: 1124 - If `data_format='channels_last'`: 1125 4D tensor with shape `(batch_size, 1, 1, channels)` 1126 - If `data_format='channels_first'`: 1127 4D tensor with shape `(batch_size, channels, 1, 1)` 1128 """ 1129 1130 def call(self, inputs): 1131 if self.data_format == 'channels_last': 1132 return backend.mean(inputs, axis=[1, 2], keepdims=self.keepdims) 1133 else: 1134 return backend.mean(inputs, axis=[2, 3], keepdims=self.keepdims) 1135 1136 1137@keras_export('keras.layers.GlobalMaxPool2D', 'keras.layers.GlobalMaxPooling2D') 1138class GlobalMaxPooling2D(GlobalPooling2D): 1139 """Global max pooling operation for spatial data. 1140 1141 Examples: 1142 1143 >>> input_shape = (2, 4, 5, 3) 1144 >>> x = tf.random.normal(input_shape) 1145 >>> y = tf.keras.layers.GlobalMaxPool2D()(x) 1146 >>> print(y.shape) 1147 (2, 3) 1148 1149 Args: 1150 data_format: A string, 1151 one of `channels_last` (default) or `channels_first`. 1152 The ordering of the dimensions in the inputs. 1153 `channels_last` corresponds to inputs with shape 1154 `(batch, height, width, channels)` while `channels_first` 1155 corresponds to inputs with shape 1156 `(batch, channels, height, width)`. 1157 It defaults to the `image_data_format` value found in your 1158 Keras config file at `~/.keras/keras.json`. 1159 If you never set it, then it will be "channels_last". 1160 keepdims: A boolean, whether to keep the spatial dimensions or not. 1161 If `keepdims` is `False` (default), the rank of the tensor is reduced 1162 for spatial dimensions. 1163 If `keepdims` is `True`, the spatial dimensions are retained with 1164 length 1. 1165 The behavior is the same as for `tf.reduce_max` or `np.max`. 1166 1167 Input shape: 1168 - If `data_format='channels_last'`: 1169 4D tensor with shape `(batch_size, rows, cols, channels)`. 1170 - If `data_format='channels_first'`: 1171 4D tensor with shape `(batch_size, channels, rows, cols)`. 1172 1173 Output shape: 1174 - If `keepdims`=False: 1175 2D tensor with shape `(batch_size, channels)`. 1176 - If `keepdims`=True: 1177 - If `data_format='channels_last'`: 1178 4D tensor with shape `(batch_size, 1, 1, channels)` 1179 - If `data_format='channels_first'`: 1180 4D tensor with shape `(batch_size, channels, 1, 1)` 1181 """ 1182 1183 def call(self, inputs): 1184 if self.data_format == 'channels_last': 1185 return backend.max(inputs, axis=[1, 2], keepdims=self.keepdims) 1186 else: 1187 return backend.max(inputs, axis=[2, 3], keepdims=self.keepdims) 1188 1189 1190class GlobalPooling3D(Layer): 1191 """Abstract class for different global pooling 3D layers.""" 1192 1193 def __init__(self, data_format=None, keepdims=False, **kwargs): 1194 super(GlobalPooling3D, self).__init__(**kwargs) 1195 self.data_format = conv_utils.normalize_data_format(data_format) 1196 self.input_spec = InputSpec(ndim=5) 1197 self.keepdims = keepdims 1198 1199 def compute_output_shape(self, input_shape): 1200 input_shape = tensor_shape.TensorShape(input_shape).as_list() 1201 if self.data_format == 'channels_last': 1202 if self.keepdims: 1203 return tensor_shape.TensorShape( 1204 [input_shape[0], 1, 1, 1, input_shape[4]]) 1205 else: 1206 return tensor_shape.TensorShape([input_shape[0], input_shape[4]]) 1207 else: 1208 if self.keepdims: 1209 return tensor_shape.TensorShape( 1210 [input_shape[0], input_shape[1], 1, 1, 1]) 1211 else: 1212 return tensor_shape.TensorShape([input_shape[0], input_shape[1]]) 1213 1214 def call(self, inputs): 1215 raise NotImplementedError 1216 1217 def get_config(self): 1218 config = {'data_format': self.data_format, 'keepdims': self.keepdims} 1219 base_config = super(GlobalPooling3D, self).get_config() 1220 return dict(list(base_config.items()) + list(config.items())) 1221 1222 1223@keras_export('keras.layers.GlobalAveragePooling3D', 1224 'keras.layers.GlobalAvgPool3D') 1225class GlobalAveragePooling3D(GlobalPooling3D): 1226 """Global Average pooling operation for 3D data. 1227 1228 Args: 1229 data_format: A string, 1230 one of `channels_last` (default) or `channels_first`. 1231 The ordering of the dimensions in the inputs. 1232 `channels_last` corresponds to inputs with shape 1233 `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 1234 while `channels_first` corresponds to inputs with shape 1235 `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 1236 It defaults to the `image_data_format` value found in your 1237 Keras config file at `~/.keras/keras.json`. 1238 If you never set it, then it will be "channels_last". 1239 keepdims: A boolean, whether to keep the spatial dimensions or not. 1240 If `keepdims` is `False` (default), the rank of the tensor is reduced 1241 for spatial dimensions. 1242 If `keepdims` is `True`, the spatial dimensions are retained with 1243 length 1. 1244 The behavior is the same as for `tf.reduce_mean` or `np.mean`. 1245 1246 Input shape: 1247 - If `data_format='channels_last'`: 1248 5D tensor with shape: 1249 `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 1250 - If `data_format='channels_first'`: 1251 5D tensor with shape: 1252 `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` 1253 1254 Output shape: 1255 - If `keepdims`=False: 1256 2D tensor with shape `(batch_size, channels)`. 1257 - If `keepdims`=True: 1258 - If `data_format='channels_last'`: 1259 5D tensor with shape `(batch_size, 1, 1, 1, channels)` 1260 - If `data_format='channels_first'`: 1261 5D tensor with shape `(batch_size, channels, 1, 1, 1)` 1262 """ 1263 1264 def call(self, inputs): 1265 if self.data_format == 'channels_last': 1266 return backend.mean(inputs, axis=[1, 2, 3], keepdims=self.keepdims) 1267 else: 1268 return backend.mean(inputs, axis=[2, 3, 4], keepdims=self.keepdims) 1269 1270 1271@keras_export('keras.layers.GlobalMaxPool3D', 'keras.layers.GlobalMaxPooling3D') 1272class GlobalMaxPooling3D(GlobalPooling3D): 1273 """Global Max pooling operation for 3D data. 1274 1275 Args: 1276 data_format: A string, 1277 one of `channels_last` (default) or `channels_first`. 1278 The ordering of the dimensions in the inputs. 1279 `channels_last` corresponds to inputs with shape 1280 `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 1281 while `channels_first` corresponds to inputs with shape 1282 `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 1283 It defaults to the `image_data_format` value found in your 1284 Keras config file at `~/.keras/keras.json`. 1285 If you never set it, then it will be "channels_last". 1286 keepdims: A boolean, whether to keep the spatial dimensions or not. 1287 If `keepdims` is `False` (default), the rank of the tensor is reduced 1288 for spatial dimensions. 1289 If `keepdims` is `True`, the spatial dimensions are retained with 1290 length 1. 1291 The behavior is the same as for `tf.reduce_max` or `np.max`. 1292 1293 Input shape: 1294 - If `data_format='channels_last'`: 1295 5D tensor with shape: 1296 `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 1297 - If `data_format='channels_first'`: 1298 5D tensor with shape: 1299 `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)` 1300 1301 Output shape: 1302 - If `keepdims`=False: 1303 2D tensor with shape `(batch_size, channels)`. 1304 - If `keepdims`=True: 1305 - If `data_format='channels_last'`: 1306 5D tensor with shape `(batch_size, 1, 1, 1, channels)` 1307 - If `data_format='channels_first'`: 1308 5D tensor with shape `(batch_size, channels, 1, 1, 1)` 1309 """ 1310 1311 def call(self, inputs): 1312 if self.data_format == 'channels_last': 1313 return backend.max(inputs, axis=[1, 2, 3], keepdims=self.keepdims) 1314 else: 1315 return backend.max(inputs, axis=[2, 3, 4], keepdims=self.keepdims) 1316 1317 1318# Aliases 1319 1320AvgPool1D = AveragePooling1D 1321MaxPool1D = MaxPooling1D 1322AvgPool2D = AveragePooling2D 1323MaxPool2D = MaxPooling2D 1324AvgPool3D = AveragePooling3D 1325MaxPool3D = MaxPooling3D 1326GlobalMaxPool1D = GlobalMaxPooling1D 1327GlobalMaxPool2D = GlobalMaxPooling2D 1328GlobalMaxPool3D = GlobalMaxPooling3D 1329GlobalAvgPool1D = GlobalAveragePooling1D 1330GlobalAvgPool2D = GlobalAveragePooling2D 1331GlobalAvgPool3D = GlobalAveragePooling3D 1332