# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Convolutional-recurrent layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import keras_export


class ConvRNN2D(RNN):
  """Base class for convolutional-recurrent layers.

  Arguments:
    cell: A RNN cell instance. A RNN cell is a class that has:
      - a `call(input_at_t, states_at_t)` method, returning
        `(output_at_t, states_at_t_plus_1)`. The call method of the
        cell can also take the optional argument `constants`, see
        section "Note on passing external constants" below.
      - a `state_size` attribute. This can be a single integer
        (single state) in which case it is
        the number of channels of the recurrent state
        (which should be the same as the number of channels of the cell
        output). This can also be a list/tuple of integers
        (one size per state). In this case, the first entry
        (`state_size[0]`) should be the same as
        the size of the cell output.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    input_shape: Use this argument to specify the shape of the
      input when this layer is the first one in a model.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is for use with cells that use dropout.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
    constants: List of constant tensors to be passed to the cell at each
      timestep.

  Input shape:
    5D tensor with shape:
    `(samples, timesteps, channels, rows, cols)`
    if data_format='channels_first' or 5D tensor with shape:
    `(samples, timesteps, rows, cols, channels)`
    if data_format='channels_last'.

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states,
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Masking:
    This layer supports masking for input data with a variable number
    of timesteps.

  Note on using statefulness in RNNs:
    You can set RNN layers to be 'stateful', which means that the states
    computed for the samples in one batch will be reused as initial states
    for the samples in the next batch. This assumes a one-to-one mapping
    between samples in different successive batches.
    To enable statefulness:
      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model, by passing
        - If sequential model:
          `batch_input_shape=(...)` to the first layer in your model.
        - If functional model with 1 or more Input layers:
          `batch_shape=(...)` to all the first layers in your model.
        This is the expected shape of your inputs
        *including the batch size*.
        It should be a tuple of integers,
        e.g. `(32, 10, 100, 100, 32)`.
        Note that the number of rows and columns should be specified
        too.
      - Specify `shuffle=False` when calling fit().
    To reset the states of your model, call `.reset_states()` on either
    a specific layer, or on your entire model.

  Note on specifying the initial state of RNNs:
    You can specify the initial state of RNN layers symbolically by
    calling them with the keyword argument `initial_state`. The value of
    `initial_state` should be a tensor or list of tensors representing
    the initial state of the RNN layer.
    You can specify the initial state of RNN layers numerically by
    calling `reset_states` with the keyword argument `states`. The value of
    `states` should be a numpy array or list of numpy arrays representing
    the initial state of the RNN layer.

  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
    requires that the `cell.call` method accepts the same keyword argument
    `constants`. Such constants can be used to condition the cell
    transformation on additional static inputs (not changing over time),
    a.k.a. an attention mechanism.
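
  Example:

    A minimal sketch of wrapping a cell in this base class (the shapes and
    hyperparameters below are arbitrary, and `tf` is assumed to be the
    imported `tensorflow` module). In practice you would usually use the
    `ConvLSTM2D` layer defined in this module rather than building the
    wrapper yourself:

        inputs = tf.keras.Input(shape=(10, 32, 32, 3))
        cell = ConvLSTM2DCell(filters=16, kernel_size=(3, 3), padding='same')
        outputs = ConvRNN2D(cell, return_sequences=True)(inputs)
        # With data_format='channels_last', `outputs` has shape
        # (batch_size, 10, 32, 32, 16).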
  """

  def __init__(self,
               cell,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               **kwargs):
    if unroll:
      raise TypeError('Unrolling isn\'t possible with '
                      'convolutional RNNs.')
    if isinstance(cell, (list, tuple)):
      # The StackedConvRNN2DCells isn't implemented yet.
      raise TypeError('It is not possible at the moment to '
                      'stack convolutional cells.')
    super(ConvRNN2D, self).__init__(cell,
                                    return_sequences,
                                    return_state,
                                    go_backwards,
                                    stateful,
                                    unroll,
                                    **kwargs)
    self.input_spec = [InputSpec(ndim=5)]
    self.states = None
    self._num_constants = None

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    cell = self.cell
    if cell.data_format == 'channels_first':
      rows = input_shape[3]
      cols = input_shape[4]
    elif cell.data_format == 'channels_last':
      rows = input_shape[2]
      cols = input_shape[3]
    rows = conv_utils.conv_output_length(rows,
                                         cell.kernel_size[0],
                                         padding=cell.padding,
                                         stride=cell.strides[0],
                                         dilation=cell.dilation_rate[0])
    cols = conv_utils.conv_output_length(cols,
                                         cell.kernel_size[1],
                                         padding=cell.padding,
                                         stride=cell.strides[1],
                                         dilation=cell.dilation_rate[1])

    if cell.data_format == 'channels_first':
      output_shape = input_shape[:2] + (cell.filters, rows, cols)
    elif cell.data_format == 'channels_last':
      output_shape = input_shape[:2] + (rows, cols, cell.filters)

    if not self.return_sequences:
      output_shape = output_shape[:1] + output_shape[2:]

    if self.return_state:
      output_shape = [output_shape]
      if cell.data_format == 'channels_first':
        output_shape += [(input_shape[0], cell.filters, rows, cols)
                         for _ in range(2)]
      elif cell.data_format == 'channels_last':
        output_shape += [(input_shape[0], rows, cols, cell.filters)
                         for _ in range(2)]
    return output_shape

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Note input_shape will be list of shapes of initial states and
    # constants if these are passed in __call__.
    if self._num_constants is not None:
      constants_shape = input_shape[-self._num_constants:]  # pylint: disable=E1130
    else:
      constants_shape = None

    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    batch_size = input_shape[0] if self.stateful else None
    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])

    # allow cell (if layer) to build before we set or validate state_spec
    if isinstance(self.cell, Layer):
      step_input_shape = (input_shape[0],) + input_shape[2:]
      if constants_shape is not None:
        self.cell.build([step_input_shape] + constants_shape)
      else:
        self.cell.build(step_input_shape)

    # set or validate state_spec
    if hasattr(self.cell.state_size, '__len__'):
      state_size = list(self.cell.state_size)
    else:
      state_size = [self.cell.state_size]

    if self.state_spec is not None:
      # initial_state was passed in call, check compatibility
      if self.cell.data_format == 'channels_first':
        ch_dim = 1
      elif self.cell.data_format == 'channels_last':
        ch_dim = 3
      if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
        raise ValueError(
            'An initial_state was passed that is not compatible with '
            '`cell.state_size`. Received `state_spec`={}; '
            'However `cell.state_size` is '
            '{}'.format([spec.shape for spec in self.state_spec],
                        self.cell.state_size))
    else:
      if self.cell.data_format == 'channels_first':
        self.state_spec = [InputSpec(shape=(None, dim, None, None))
                           for dim in state_size]
      elif self.cell.data_format == 'channels_last':
        self.state_spec = [InputSpec(shape=(None, None, None, dim))
                           for dim in state_size]
    if self.stateful:
      self.reset_states()
    self.built = True

  def get_initial_state(self, inputs):
    # (samples, timesteps, rows, cols, filters)
    initial_state = K.zeros_like(inputs)
    # (samples, rows, cols, filters)
    initial_state = K.sum(initial_state, axis=1)
    shape = list(self.cell.kernel_shape)
    shape[-1] = self.cell.filters
    initial_state = self.cell.input_conv(initial_state,
                                         array_ops.zeros(tuple(shape)),
                                         padding=self.cell.padding)

    if hasattr(self.cell.state_size, '__len__'):
      return [initial_state for _ in self.cell.state_size]
    else:
      return [initial_state]

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if initial_state is None and constants is None:
      return super(ConvRNN2D, self).__call__(inputs, **kwargs)

    # If any of `initial_state` or `constants` are specified and are Keras
    # tensors, then add them to the inputs and temporarily modify the
    # input_spec to include them.

    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      kwargs['initial_state'] = initial_state
      additional_inputs += initial_state
      self.state_spec = []
      for state in initial_state:
        shape = K.int_shape(state)
        self.state_spec.append(InputSpec(shape=shape))

      additional_specs += self.state_spec
    if constants is not None:
      kwargs['constants'] = constants
      additional_inputs += constants
      self.constants_spec = [InputSpec(shape=K.int_shape(constant))
                             for constant in constants]
      self._num_constants = len(constants)
      additional_specs += self.constants_spec
    # at this point additional_inputs cannot be empty
    for tensor in additional_inputs:
      if K.is_keras_tensor(tensor) != K.is_keras_tensor(additional_inputs[0]):
        raise ValueError('The initial state or constants of an RNN'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors')

    if K.is_keras_tensor(additional_inputs[0]):
      # Compute the full input spec, including state and constants
      full_input = [inputs] + additional_inputs
      full_input_spec = self.input_spec + additional_specs
      # Perform the call with temporarily replaced input_spec
      original_input_spec = self.input_spec
      self.input_spec = full_input_spec
      output = super(ConvRNN2D, self).__call__(full_input, **kwargs)
      self.input_spec = original_input_spec
      return output
    else:
      return super(ConvRNN2D, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
      inputs = inputs[0]
    if initial_state is not None:
      pass
    elif self.stateful:
      initial_state = self.states
    else:
      initial_state = self.get_initial_state(inputs)

    if isinstance(mask, list):
      mask = mask[0]

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' +
                       str(len(initial_state)) +
                       ' initial states.')
    timesteps = K.int_shape(inputs)[1]

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        constants = states[-self._num_constants:]
        states = states[:-self._num_constants]
        return self.cell.call(inputs, states, constants=constants,
                              **kwargs)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = K.rnn(step,
                                         inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
      updates = []
      for i in range(len(states)):
        updates.append(K.update(self.states[i], states[i]))
      self.add_update(updates, inputs=True)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return [output] + states
    else:
      return output

  def reset_states(self, states=None):
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    input_shape = self.input_spec[0].shape
    state_shape = self.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = state_shape[0]
    if self.return_sequences:
      state_shape = state_shape[:1].concatenate(state_shape[2:])
    if None in state_shape:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.\n'
                       'The same thing goes for the number of rows and '
                       'columns.')

    # helper function
    def get_tuple_shape(nb_channels):
      result = list(state_shape)
      if self.cell.data_format == 'channels_first':
        result[1] = nb_channels
      elif self.cell.data_format == 'channels_last':
        result[3] = nb_channels
      else:
        raise KeyError
      return tuple(result)

    # initialize state if None
    if self.states[0] is None:
      if hasattr(self.cell.state_size, '__len__'):
        self.states = [K.zeros(get_tuple_shape(dim))
                       for dim in self.cell.state_size]
      else:
        self.states = [K.zeros(get_tuple_shape(self.cell.state_size))]
    elif states is None:
      if hasattr(self.cell.state_size, '__len__'):
        for state, dim in zip(self.states, self.cell.state_size):
          K.set_value(state, np.zeros(get_tuple_shape(dim)))
      else:
        K.set_value(self.states[0],
                    np.zeros(get_tuple_shape(self.cell.state_size)))
    else:
      if not isinstance(states, (list, tuple)):
        states = [states]
      if len(states) != len(self.states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(self.states)) + ' states, ' +
                         'but it received ' + str(len(states)) +
                         ' state values. Input received: ' + str(states))
      for index, (value, state) in enumerate(zip(states, self.states)):
        if hasattr(self.cell.state_size, '__len__'):
          dim = self.cell.state_size[index]
        else:
          dim = self.cell.state_size
        if value.shape != get_tuple_shape(dim):
          raise ValueError('State ' + str(index) +
                           ' is incompatible with layer ' +
                           self.name + ': expected shape=' +
                           str(get_tuple_shape(dim)) +
                           ', found shape=' + str(value.shape))
        # TODO(anjalisridhar): consider batch calls to `set_value`.
        K.set_value(state, value)


class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
  """Cell class for the ConvLSTM2D layer.

  Arguments:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al.]
      (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 4D tensor.
    states: List of state tensors corresponding to the previous timestep.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    super(ConvLSTM2DCell, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
                                                    'dilation_rate')
    self.activation = activations.get(activation)
    self.recurrent_activation = activations.get(recurrent_activation)
    self.use_bias = use_bias

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    self.state_size = (self.filters, self.filters)

  def build(self, input_shape):

    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    if input_shape[channel_axis] is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    input_dim = input_shape[channel_axis]
    kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
    self.kernel_shape = kernel_shape
    recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)

    self.kernel = self.add_weight(shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=recurrent_kernel_shape,
        initializer=self.recurrent_initializer,
        name='recurrent_kernel',
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      if self.unit_forget_bias:

        def bias_initializer(_, *args, **kwargs):
          return K.concatenate([
              self.bias_initializer((self.filters,), *args, **kwargs),
              initializers.Ones()((self.filters,), *args, **kwargs),
              self.bias_initializer((self.filters * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.filters * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    # dropout matrices for input units
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    # dropout matrices for recurrent units
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=4)

    if 0 < self.dropout < 1.:
      inputs_i = inputs * dp_mask[0]
      inputs_f = inputs * dp_mask[1]
      inputs_c = inputs * dp_mask[2]
      inputs_o = inputs * dp_mask[3]
    else:
      inputs_i = inputs
      inputs_f = inputs
      inputs_c = inputs
      inputs_o = inputs

    if 0 < self.recurrent_dropout < 1.:
      h_tm1_i = h_tm1 * rec_dp_mask[0]
      h_tm1_f = h_tm1 * rec_dp_mask[1]
      h_tm1_c = h_tm1 * rec_dp_mask[2]
      h_tm1_o = h_tm1 * rec_dp_mask[3]
    else:
      h_tm1_i = h_tm1
      h_tm1_f = h_tm1
      h_tm1_c = h_tm1
      h_tm1_o = h_tm1

    (kernel_i, kernel_f,
     kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
    (recurrent_kernel_i,
     recurrent_kernel_f,
     recurrent_kernel_c,
     recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)

    if self.use_bias:
      bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
    else:
      bias_i, bias_f, bias_c, bias_o = None, None, None, None

    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)
    return h, [h, c]

  def input_conv(self, x, w, b=None, padding='valid'):
    conv_out = K.conv2d(x, w, strides=self.strides,
                        padding=padding,
                        data_format=self.data_format,
                        dilation_rate=self.dilation_rate)
    if b is not None:
      conv_out = K.bias_add(conv_out, b,
                            data_format=self.data_format)
    return conv_out

  def recurrent_conv(self, x, w):
    conv_out = K.conv2d(x, w, strides=(1, 1),
                        padding='same',
                        data_format=self.data_format)
    return conv_out

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2DCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.ConvLSTM2D')
class ConvLSTM2D(ConvRNN2D):
  """Convolutional LSTM.

  It is similar to an LSTM layer, but the input transformations
  and recurrent transformations are both convolutional.

  Arguments:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, time, ..., channels)`
      while `channels_first` corresponds to
      inputs with shape `(batch, time, channels, ...)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al.]
      (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or
      `recurrent_dropout` are set.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
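
  Example:

    A minimal sketch of a two-layer ConvLSTM2D stack; the shapes,
    hyperparameters and the `tf` import alias are illustrative assumptions,
    not requirements of the layer:

        import tensorflow as tf

        # Batches of 10-step sequences of 40x40 single-channel frames,
        # i.e. (batch, time, rows, cols, channels) with
        # data_format='channels_last'.
        inputs = tf.keras.Input(shape=(10, 40, 40, 1))
        x = tf.keras.layers.ConvLSTM2D(filters=32, kernel_size=(3, 3),
                                       padding='same',
                                       return_sequences=True)(inputs)
        # `x` has shape (batch, 10, 40, 40, 32).
        x = tf.keras.layers.ConvLSTM2D(filters=16, kernel_size=(3, 3),
                                       padding='same')(x)
        # `x` has shape (batch, 40, 40, 16) since return_sequences=False.
        model = tf.keras.Model(inputs, x)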

  Input shape:
    - If data_format='channels_first'
      5D tensor with shape:
      `(samples, time, channels, rows, cols)`
    - If data_format='channels_last'
      5D tensor with shape:
      `(samples, time, rows, cols, channels)`

  Output shape:
    - If `return_sequences`
      - If data_format='channels_first'
        5D tensor with shape:
        `(samples, time, filters, output_row, output_col)`
      - If data_format='channels_last'
        5D tensor with shape:
        `(samples, time, output_row, output_col, filters)`
    - Else
      - If data_format='channels_first'
        4D tensor with shape:
        `(samples, filters, output_row, output_col)`
      - If data_format='channels_last'
        4D tensor with shape:
        `(samples, output_row, output_col, filters)`
      where `output_row` and `output_col` depend on the shape of the filter
      and the padding.

  Raises:
    ValueError: in case of invalid constructor arguments.

  References:
    - [Convolutional LSTM Network: A Machine Learning Approach for
      Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
      The current implementation does not include the feedback loop on the
      cell's output.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               go_backwards=False,
               stateful=False,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    cell = ConvLSTM2DCell(filters=filters,
                          kernel_size=kernel_size,
                          strides=strides,
                          padding=padding,
                          data_format=data_format,
                          dilation_rate=dilation_rate,
                          activation=activation,
                          recurrent_activation=recurrent_activation,
                          use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          recurrent_initializer=recurrent_initializer,
                          bias_initializer=bias_initializer,
                          unit_forget_bias=unit_forget_bias,
                          kernel_regularizer=kernel_regularizer,
                          recurrent_regularizer=recurrent_regularizer,
                          bias_regularizer=bias_regularizer,
                          kernel_constraint=kernel_constraint,
                          recurrent_constraint=recurrent_constraint,
                          bias_constraint=bias_constraint,
                          dropout=dropout,
                          recurrent_dropout=recurrent_dropout)
    super(ConvLSTM2D, self).__init__(cell,
                                     return_sequences=return_sequences,
                                     go_backwards=go_backwards,
                                     stateful=stateful,
                                     **kwargs)
    self.activity_regularizer = regularizers.get(activity_regularizer)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    self.cell.reset_dropout_mask()
    self.cell.reset_recurrent_dropout_mask()
    return super(ConvLSTM2D, self).call(inputs,
                                        mask=mask,
                                        training=training,
                                        initial_state=initial_state)

  @property
  def filters(self):
    return self.cell.filters

  @property
  def kernel_size(self):
    return self.cell.kernel_size

  @property
  def strides(self):
    return self.cell.strides

  @property
  def padding(self):
    return self.cell.padding

  @property
  def data_format(self):
    return self.cell.data_format

  @property
  def dilation_rate(self):
    return self.cell.dilation_rate

  @property
  def activation(self):
    return self.cell.activation

  @property
  def recurrent_activation(self):
    return self.cell.recurrent_activation

  @property
  def use_bias(self):
    return self.cell.use_bias

  @property
  def kernel_initializer(self):
    return self.cell.kernel_initializer

  @property
  def recurrent_initializer(self):
    return self.cell.recurrent_initializer

  @property
  def bias_initializer(self):
    return self.cell.bias_initializer

  @property
  def unit_forget_bias(self):
    return self.cell.unit_forget_bias

  @property
  def kernel_regularizer(self):
    return self.cell.kernel_regularizer

  @property
  def recurrent_regularizer(self):
    return self.cell.recurrent_regularizer

  @property
  def bias_regularizer(self):
    return self.cell.bias_regularizer

  @property
  def kernel_constraint(self):
    return self.cell.kernel_constraint

  @property
  def recurrent_constraint(self):
    return self.cell.recurrent_constraint

  @property
  def bias_constraint(self):
    return self.cell.bias_constraint

  @property
  def dropout(self):
    return self.cell.dropout

  @property
  def recurrent_dropout(self):
    return self.cell.recurrent_dropout

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'activity_regularizer': regularizers.serialize(
                  self.activity_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2D, self).get_config()
    del base_config['cell']
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    return cls(**config)