# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Convolutional-recurrent layers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import keras_export


class ConvRNN2D(RNN):
  """Base class for convolutional-recurrent layers.

  Arguments:
    cell: A RNN cell instance. A RNN cell is a class that has:
      - a `call(input_at_t, states_at_t)` method, returning
        `(output_at_t, states_at_t_plus_1)`. The call method of the
        cell can also take the optional argument `constants`, see
        section "Note on passing external constants" below.
      - a `state_size` attribute. This can be a single integer
        (single state) in which case it is
        the number of channels of the recurrent state
        (which should be the same as the number of channels of the cell
        output). This can also be a list/tuple of integers
        (one size per state). In this case, the first entry
        (`state_size[0]`) should be the same as
        the size of the cell output.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    input_shape: Use this argument to specify the shape of the
      input when this layer is the first one in a model.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is for use with cells that use dropout.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
    constants: List of constant tensors to be passed to the cell at each
      timestep.

  Input shape:
    5D tensor with shape:
    `(samples, timesteps, channels, rows, cols)`
    if data_format='channels_first' or 5D tensor with shape:
    `(samples, timesteps, rows, cols, channels)`
    if data_format='channels_last'.

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states,
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Masking:
    This layer supports masking for input data with a variable number
    of timesteps.

  Note on using statefulness in RNNs:
    You can set RNN layers to be 'stateful', which means that the states
    computed for the samples in one batch will be reused as initial states
    for the samples in the next batch. This assumes a one-to-one mapping
    between samples in different successive batches.
    To enable statefulness:
      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model, by passing
         - If sequential model:
            `batch_input_shape=(...)` to the first layer in your model.
         - If functional model with 1 or more Input layers:
            `batch_shape=(...)` to all the first layers in your model.
            This is the expected shape of your inputs
            *including the batch size*.
            It should be a tuple of integers,
            e.g. `(32, 10, 100, 100, 32)`.
            Note that the number of rows and columns should be specified
            too.
      - Specify `shuffle=False` when calling fit().
    To reset the states of your model, call `.reset_states()` on either
    a specific layer, or on your entire model.

  Note on specifying the initial state of RNNs:
    You can specify the initial state of RNN layers symbolically by
    calling them with the keyword argument `initial_state`. The value of
    `initial_state` should be a tensor or list of tensors representing
    the initial state of the RNN layer.
    You can specify the initial state of RNN layers numerically by
    calling `reset_states` with the keyword argument `states`. The value of
    `states` should be a numpy array or list of numpy arrays representing
    the initial state of the RNN layer.

  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
    requires that the `cell.call` method accepts the same keyword argument
    `constants`. Such constants can be used to condition the cell
    transformation on additional static inputs (not changing over time),
    a.k.a. an attention mechanism.
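
  Example (an illustrative sketch of the statefulness notes above; the
  layer, batch size and frame dimensions are arbitrary and assume the
  public `tf.keras` API):

  ```python
  import numpy as np
  import tensorflow as tf

  # A stateful convolutional-recurrent layer needs a fixed batch size,
  # given here via `batch_input_shape=(batch, time, rows, cols, channels)`.
  model = tf.keras.Sequential([
      tf.keras.layers.ConvLSTM2D(
          filters=8, kernel_size=(3, 3), padding='same', stateful=True,
          batch_input_shape=(4, 5, 20, 20, 1)),
  ])
  frames = np.random.random((4, 5, 20, 20, 1)).astype(np.float32)
  model.predict(frames)  # final states are kept for the next batch
  model.reset_states()   # clear them explicitly between sequences
  ```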
  """

  def __init__(self,
               cell,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               **kwargs):
    if unroll:
      raise TypeError('Unrolling isn\'t possible with '
                      'convolutional RNNs.')
    if isinstance(cell, (list, tuple)):
      # StackedConvRNN2DCells isn't implemented yet.
      raise TypeError('It is not possible at the moment to '
                      'stack convolutional cells.')
    super(ConvRNN2D, self).__init__(cell,
                                    return_sequences,
                                    return_state,
                                    go_backwards,
                                    stateful,
                                    unroll,
                                    **kwargs)
    self.input_spec = [InputSpec(ndim=5)]
    self.states = None
    self._num_constants = None

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    cell = self.cell
    if cell.data_format == 'channels_first':
      rows = input_shape[3]
      cols = input_shape[4]
    elif cell.data_format == 'channels_last':
      rows = input_shape[2]
      cols = input_shape[3]
    rows = conv_utils.conv_output_length(rows,
                                         cell.kernel_size[0],
                                         padding=cell.padding,
                                         stride=cell.strides[0],
                                         dilation=cell.dilation_rate[0])
    cols = conv_utils.conv_output_length(cols,
                                         cell.kernel_size[1],
                                         padding=cell.padding,
                                         stride=cell.strides[1],
                                         dilation=cell.dilation_rate[1])

    if cell.data_format == 'channels_first':
      output_shape = input_shape[:2] + (cell.filters, rows, cols)
    elif cell.data_format == 'channels_last':
      output_shape = input_shape[:2] + (rows, cols, cell.filters)

    if not self.return_sequences:
      output_shape = output_shape[:1] + output_shape[2:]

    if self.return_state:
      output_shape = [output_shape]
      if cell.data_format == 'channels_first':
        output_shape += [(input_shape[0], cell.filters, rows, cols)
                         for _ in range(2)]
      elif cell.data_format == 'channels_last':
        output_shape += [(input_shape[0], rows, cols, cell.filters)
                         for _ in range(2)]
    return output_shape

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Note input_shape will be list of shapes of initial states and
    # constants if these are passed in __call__.
    if self._num_constants is not None:
      constants_shape = input_shape[-self._num_constants:]  # pylint: disable=E1130
    else:
      constants_shape = None

    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    batch_size = input_shape[0] if self.stateful else None
    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])

    # allow cell (if layer) to build before we set or validate state_spec
    if isinstance(self.cell, Layer):
      step_input_shape = (input_shape[0],) + input_shape[2:]
      if constants_shape is not None:
        self.cell.build([step_input_shape] + constants_shape)
      else:
        self.cell.build(step_input_shape)

    # set or validate state_spec
    if hasattr(self.cell.state_size, '__len__'):
      state_size = list(self.cell.state_size)
    else:
      state_size = [self.cell.state_size]

    if self.state_spec is not None:
      # initial_state was passed in call, check compatibility
      if self.cell.data_format == 'channels_first':
        ch_dim = 1
      elif self.cell.data_format == 'channels_last':
        ch_dim = 3
      if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
        raise ValueError(
            'An initial_state was passed that is not compatible with '
            '`cell.state_size`. Received `state_spec`={}; '
            'However `cell.state_size` is '
            '{}'.format([spec.shape for spec in self.state_spec],
                        self.cell.state_size))
    else:
      if self.cell.data_format == 'channels_first':
        self.state_spec = [InputSpec(shape=(None, dim, None, None))
                           for dim in state_size]
      elif self.cell.data_format == 'channels_last':
        self.state_spec = [InputSpec(shape=(None, None, None, dim))
                           for dim in state_size]
    if self.stateful:
      self.reset_states()
    self.built = True

  def get_initial_state(self, inputs):
    # (samples, timesteps, rows, cols, filters)
    initial_state = K.zeros_like(inputs)
    # (samples, rows, cols, filters)
    initial_state = K.sum(initial_state, axis=1)
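    # Convolving the all-zeros state with an all-zeros kernel (below) yields
    # a zero tensor whose spatial dimensions and channel count match the
    # cell output, accounting for the cell's padding, strides and dilation.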
    shape = list(self.cell.kernel_shape)
    shape[-1] = self.cell.filters
    initial_state = self.cell.input_conv(initial_state,
                                         array_ops.zeros(tuple(shape)),
                                         padding=self.cell.padding)

    if hasattr(self.cell.state_size, '__len__'):
      return [initial_state for _ in self.cell.state_size]
    else:
      return [initial_state]

  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
    inputs, initial_state, constants = _standardize_args(
        inputs, initial_state, constants, self._num_constants)

    if initial_state is None and constants is None:
      return super(ConvRNN2D, self).__call__(inputs, **kwargs)

    # If any of `initial_state` or `constants` are specified and are Keras
    # tensors, then add them to the inputs and temporarily modify the
    # input_spec to include them.

    additional_inputs = []
    additional_specs = []
    if initial_state is not None:
      kwargs['initial_state'] = initial_state
      additional_inputs += initial_state
      self.state_spec = []
      for state in initial_state:
        shape = K.int_shape(state)
        self.state_spec.append(InputSpec(shape=shape))

      additional_specs += self.state_spec
    if constants is not None:
      kwargs['constants'] = constants
      additional_inputs += constants
      self.constants_spec = [InputSpec(shape=K.int_shape(constant))
                             for constant in constants]
      self._num_constants = len(constants)
      additional_specs += self.constants_spec
    # at this point additional_inputs cannot be empty
    for tensor in additional_inputs:
      if K.is_keras_tensor(tensor) != K.is_keras_tensor(additional_inputs[0]):
        raise ValueError('The initial state or constants of an RNN'
                         ' layer cannot be specified with a mix of'
                         ' Keras tensors and non-Keras tensors')

    if K.is_keras_tensor(additional_inputs[0]):
      # Compute the full input spec, including state and constants
      full_input = [inputs] + additional_inputs
      full_input_spec = self.input_spec + additional_specs
      # Perform the call with temporarily replaced input_spec
      original_input_spec = self.input_spec
      self.input_spec = full_input_spec
      output = super(ConvRNN2D, self).__call__(full_input, **kwargs)
      self.input_spec = original_input_spec
      return output
    else:
      return super(ConvRNN2D, self).__call__(inputs, **kwargs)

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    if isinstance(inputs, list):
      inputs = inputs[0]
    if initial_state is not None:
      pass
    elif self.stateful:
      initial_state = self.states
    else:
      initial_state = self.get_initial_state(inputs)

    if isinstance(mask, list):
      mask = mask[0]

    if len(initial_state) != len(self.states):
      raise ValueError('Layer has ' + str(len(self.states)) +
                       ' states but was passed ' +
                       str(len(initial_state)) +
                       ' initial states.')
    timesteps = K.int_shape(inputs)[1]

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        constants = states[-self._num_constants:]
        states = states[:-self._num_constants]
        return self.cell.call(inputs, states, constants=constants,
                              **kwargs)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = K.rnn(step,
                                         inputs,
                                         initial_state,
                                         constants=constants,
                                         go_backwards=self.go_backwards,
                                         mask=mask,
                                         input_length=timesteps)
    if self.stateful:
      updates = []
      for i in range(len(states)):
        updates.append(K.update(self.states[i], states[i]))
      self.add_update(updates, inputs=True)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return [output] + states
    else:
      return output

  def reset_states(self, states=None):
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    input_shape = self.input_spec[0].shape
    state_shape = self.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = state_shape[0]
    if self.return_sequences:
      state_shape = state_shape[:1].concatenate(state_shape[2:])
    if None in state_shape:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.\n'
                       'The same thing goes for the number of rows and '
                       'columns.')

    # helper function
    def get_tuple_shape(nb_channels):
      result = list(state_shape)
      if self.cell.data_format == 'channels_first':
        result[1] = nb_channels
      elif self.cell.data_format == 'channels_last':
        result[3] = nb_channels
      else:
        raise KeyError
      return tuple(result)

    # initialize state if None
    if self.states[0] is None:
      if hasattr(self.cell.state_size, '__len__'):
        self.states = [K.zeros(get_tuple_shape(dim))
                       for dim in self.cell.state_size]
      else:
        self.states = [K.zeros(get_tuple_shape(self.cell.state_size))]
    elif states is None:
      if hasattr(self.cell.state_size, '__len__'):
        for state, dim in zip(self.states, self.cell.state_size):
          K.set_value(state, np.zeros(get_tuple_shape(dim)))
      else:
        K.set_value(self.states[0],
                    np.zeros(get_tuple_shape(self.cell.state_size)))
    else:
      if not isinstance(states, (list, tuple)):
        states = [states]
      if len(states) != len(self.states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(self.states)) + ' states, ' +
                         'but it received ' + str(len(states)) +
                         ' state values. Input received: ' + str(states))
      for index, (value, state) in enumerate(zip(states, self.states)):
        if hasattr(self.cell.state_size, '__len__'):
          dim = self.cell.state_size[index]
        else:
          dim = self.cell.state_size
        if value.shape != get_tuple_shape(dim):
          raise ValueError('State ' + str(index) +
                           ' is incompatible with layer ' +
                           self.name + ': expected shape=' +
                           str(get_tuple_shape(dim)) +
                           ', found shape=' + str(value.shape))
        # TODO(anjalisridhar): consider batch calls to `set_value`.
        K.set_value(state, value)


class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
  """Cell class for the ConvLSTM2D layer.

  Arguments:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al.]
      (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 4D tensor.
    states: List of state tensors corresponding to the previous timestep.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.
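
  Example (a minimal, illustrative sketch that wraps the cell in the
  `ConvRNN2D` base layer defined in this module; the shapes below are
  arbitrary):

  ```python
  import numpy as np

  cell = ConvLSTM2DCell(filters=8, kernel_size=(3, 3), padding='same')
  layer = ConvRNN2D(cell, return_sequences=True)
  # 2 samples, 4 timesteps, 16x16 frames with 3 channels each.
  frames = np.random.random((2, 4, 16, 16, 3)).astype(np.float32)
  outputs = layer(frames)  # shape: (2, 4, 16, 16, 8)
  ```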
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    super(ConvLSTM2DCell, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
                                                    'dilation_rate')
    self.activation = activations.get(activation)
    self.recurrent_activation = activations.get(recurrent_activation)
    self.use_bias = use_bias

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    self.state_size = (self.filters, self.filters)

  def build(self, input_shape):

    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    if input_shape[channel_axis] is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    input_dim = input_shape[channel_axis]
    kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
    self.kernel_shape = kernel_shape
    recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)

    self.kernel = self.add_weight(shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=recurrent_kernel_shape,
        initializer=self.recurrent_initializer,
        name='recurrent_kernel',
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      if self.unit_forget_bias:

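        # The bias is laid out as four blocks of size `filters`, one per
        # gate, in the order [input, forget, cell, output]; initialize the
        # forget-gate block to ones and the others with `bias_initializer`.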
        def bias_initializer(_, *args, **kwargs):
          return K.concatenate([
              self.bias_initializer((self.filters,), *args, **kwargs),
              initializers.Ones()((self.filters,), *args, **kwargs),
              self.bias_initializer((self.filters * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.filters * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    # dropout matrices for input units
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    # dropout matrices for recurrent units
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=4)

    if 0 < self.dropout < 1.:
      inputs_i = inputs * dp_mask[0]
      inputs_f = inputs * dp_mask[1]
      inputs_c = inputs * dp_mask[2]
      inputs_o = inputs * dp_mask[3]
    else:
      inputs_i = inputs
      inputs_f = inputs
      inputs_c = inputs
      inputs_o = inputs

    if 0 < self.recurrent_dropout < 1.:
      h_tm1_i = h_tm1 * rec_dp_mask[0]
      h_tm1_f = h_tm1 * rec_dp_mask[1]
      h_tm1_c = h_tm1 * rec_dp_mask[2]
      h_tm1_o = h_tm1 * rec_dp_mask[3]
    else:
      h_tm1_i = h_tm1
      h_tm1_f = h_tm1
      h_tm1_c = h_tm1
      h_tm1_o = h_tm1

    (kernel_i, kernel_f,
     kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
    (recurrent_kernel_i,
     recurrent_kernel_f,
     recurrent_kernel_c,
     recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)

    if self.use_bias:
      bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
    else:
      bias_i, bias_f, bias_c, bias_o = None, None, None, None

    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

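    # ConvLSTM gate equations: `i`, `f` and `o` are the input, forget and
    # output gates; `c` is the new cell state and `h` the cell output.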
    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)
    return h, [h, c]

  def input_conv(self, x, w, b=None, padding='valid'):
    conv_out = K.conv2d(x, w, strides=self.strides,
                        padding=padding,
                        data_format=self.data_format,
                        dilation_rate=self.dilation_rate)
    if b is not None:
      conv_out = K.bias_add(conv_out, b,
                            data_format=self.data_format)
    return conv_out

  def recurrent_conv(self, x, w):
    conv_out = K.conv2d(x, w, strides=(1, 1),
                        padding='same',
                        data_format=self.data_format)
    return conv_out

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2DCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.ConvLSTM2D')
class ConvLSTM2D(ConvRNN2D):
  """Convolutional LSTM.

  It is similar to an LSTM layer, but the input transformations
  and recurrent transformations are both convolutional.

  Arguments:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, time, ..., channels)`
      while `channels_first` corresponds to
      inputs with shape `(batch, time, channels, ...)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al.]
      (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to
      the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or `recurrent_dropout`
      are set.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.

  Input shape:
    - If data_format='channels_first'
        5D tensor with shape:
        `(samples, time, channels, rows, cols)`
    - If data_format='channels_last'
        5D tensor with shape:
        `(samples, time, rows, cols, channels)`

  Output shape:
    - If `return_sequences`
       - If data_format='channels_first'
          5D tensor with shape:
          `(samples, time, filters, output_row, output_col)`
       - If data_format='channels_last'
          5D tensor with shape:
          `(samples, time, output_row, output_col, filters)`
    - Else
      - If data_format='channels_first'
          4D tensor with shape:
          `(samples, filters, output_row, output_col)`
      - If data_format='channels_last'
          4D tensor with shape:
          `(samples, output_row, output_col, filters)`
      where `output_row` and `output_col` depend on the shape of the filter
      and the padding.

  Raises:
    ValueError: in case of invalid constructor arguments.

  References:
    - [Convolutional LSTM Network: A Machine Learning Approach for
    Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
    The current implementation does not include the feedback loop on the
    cell's output.
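
  Example (a minimal, illustrative sketch; the shapes and hyperparameters
  below are arbitrary and assume the public `tf.keras` API):

  ```python
  import numpy as np
  import tensorflow as tf

  # 32 samples, 10 timesteps, 40x40 frames with 3 channels each.
  frames = np.random.random((32, 10, 40, 40, 3)).astype(np.float32)

  layer = tf.keras.layers.ConvLSTM2D(
      filters=16, kernel_size=(3, 3), padding='same', return_sequences=False)
  outputs = layer(frames)  # shape: (32, 40, 40, 16)
  ```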
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               go_backwards=False,
               stateful=False,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    cell = ConvLSTM2DCell(filters=filters,
                          kernel_size=kernel_size,
                          strides=strides,
                          padding=padding,
                          data_format=data_format,
                          dilation_rate=dilation_rate,
                          activation=activation,
                          recurrent_activation=recurrent_activation,
                          use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          recurrent_initializer=recurrent_initializer,
                          bias_initializer=bias_initializer,
                          unit_forget_bias=unit_forget_bias,
                          kernel_regularizer=kernel_regularizer,
                          recurrent_regularizer=recurrent_regularizer,
                          bias_regularizer=bias_regularizer,
                          kernel_constraint=kernel_constraint,
                          recurrent_constraint=recurrent_constraint,
                          bias_constraint=bias_constraint,
                          dropout=dropout,
                          recurrent_dropout=recurrent_dropout)
    super(ConvLSTM2D, self).__init__(cell,
                                     return_sequences=return_sequences,
                                     go_backwards=go_backwards,
                                     stateful=stateful,
                                     **kwargs)
    self.activity_regularizer = regularizers.get(activity_regularizer)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    self.cell.reset_dropout_mask()
    self.cell.reset_recurrent_dropout_mask()
    return super(ConvLSTM2D, self).call(inputs,
                                        mask=mask,
                                        training=training,
                                        initial_state=initial_state)

  @property
  def filters(self):
    return self.cell.filters

  @property
  def kernel_size(self):
    return self.cell.kernel_size

  @property
  def strides(self):
    return self.cell.strides

  @property
  def padding(self):
    return self.cell.padding

  @property
  def data_format(self):
    return self.cell.data_format

  @property
  def dilation_rate(self):
    return self.cell.dilation_rate

  @property
  def activation(self):
    return self.cell.activation

  @property
  def recurrent_activation(self):
    return self.cell.recurrent_activation

  @property
  def use_bias(self):
    return self.cell.use_bias

  @property
  def kernel_initializer(self):
    return self.cell.kernel_initializer

  @property
  def recurrent_initializer(self):
    return self.cell.recurrent_initializer

  @property
  def bias_initializer(self):
    return self.cell.bias_initializer

  @property
  def unit_forget_bias(self):
    return self.cell.unit_forget_bias

  @property
  def kernel_regularizer(self):
    return self.cell.kernel_regularizer

  @property
  def recurrent_regularizer(self):
    return self.cell.recurrent_regularizer

  @property
  def bias_regularizer(self):
    return self.cell.bias_regularizer

  @property
  def kernel_constraint(self):
    return self.cell.kernel_constraint

  @property
  def recurrent_constraint(self):
    return self.cell.recurrent_constraint

  @property
  def bias_constraint(self):
    return self.cell.bias_constraint

  @property
  def dropout(self):
    return self.cell.dropout

  @property
  def recurrent_dropout(self):
    return self.cell.recurrent_dropout

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'activity_regularizer': regularizers.serialize(
                  self.activity_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2D, self).get_config()
    del base_config['cell']
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    return cls(**config)