# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Convolutional-recurrent layers."""

import numpy as np

from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import keras_export


class ConvRNN2D(RNN):
  """Base class for convolutional-recurrent layers.

  Args:
    cell: A RNN cell instance. A RNN cell is a class that has:
      - a `call(input_at_t, states_at_t)` method, returning
        `(output_at_t, states_at_t_plus_1)`. The call method of the
        cell can also take the optional argument `constants`, see
        section "Note on passing external constants" below.
      - a `state_size` attribute. This can be a single integer
        (single state) in which case it is
        the number of channels of the recurrent state
        (which should be the same as the number of channels of the cell
        output). This can also be a list/tuple of integers
        (one size per state). In this case, the first entry
        (`state_size[0]`) should be the same as
        the size of the cell output.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence.
    return_state: Boolean. Whether to return the last state
      in addition to the output.
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards and return the
      reversed sequence.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    input_shape: Use this argument to specify the shape of the
      input when this layer is the first one in a model.

  Call arguments:
    inputs: A 5D tensor.
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is for use with cells that use dropout.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.
    constants: List of constant tensors to be passed to the cell at each
      timestep.

  Input shape:
    5D tensor with shape:
    `(samples, timesteps, channels, rows, cols)`
    if data_format='channels_first' or 5D tensor with shape:
    `(samples, timesteps, rows, cols, channels)`
    if data_format='channels_last'.

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states,
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Masking:
    This layer supports masking for input data with a variable number
    of timesteps.

  Note on using statefulness in RNNs:
    You can set RNN layers to be 'stateful', which means that the states
    computed for the samples in one batch will be reused as initial states
    for the samples in the next batch. This assumes a one-to-one mapping
    between samples in different successive batches.
    To enable statefulness:
      - Specify `stateful=True` in the layer constructor.
      - Specify a fixed batch size for your model, by passing
         - If sequential model:
            `batch_input_shape=(...)` to the first layer in your model.
         - If functional model with 1 or more Input layers:
            `batch_shape=(...)` to all the first layers in your model.
            This is the expected shape of your inputs
            *including the batch size*.
            It should be a tuple of integers,
            e.g. `(32, 10, 100, 100, 32)`.
            Note that the number of rows and columns should be specified
            too.
      - Specify `shuffle=False` when calling fit().
    To reset the states of your model, call `.reset_states()` on either
    a specific layer, or on your entire model.
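
    For example, a minimal illustrative sketch using the `ConvLSTM2D`
    subclass defined below (the batch size, sequence length and image
    shape here are arbitrary, not required values):

    ```python
    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.ConvLSTM2D(
            filters=32, kernel_size=3, padding='same', stateful=True,
            batch_input_shape=(16, 10, 64, 64, 1)),
    ])
    # States persist across successive batches of 16 matching samples;
    # clear them explicitly between independent sequences.
    model.reset_states()
    ```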

  Note on specifying the initial state of RNNs:
    You can specify the initial state of RNN layers symbolically by
    calling them with the keyword argument `initial_state`. The value of
    `initial_state` should be a tensor or list of tensors representing
    the initial state of the RNN layer.
    You can specify the initial state of RNN layers numerically by
    calling `reset_states` with the keyword argument `states`. The value of
    `states` should be a numpy array or list of numpy arrays representing
    the initial state of the RNN layer.
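
    For example, a minimal sketch of the symbolic form, again using
    `ConvLSTM2D` (all shapes are illustrative; the state channel count
    must equal `filters`):

    ```python
    import tensorflow as tf

    inputs = tf.keras.Input(shape=(10, 64, 64, 1))
    h0 = tf.keras.Input(shape=(64, 64, 32))   # initial hidden state
    c0 = tf.keras.Input(shape=(64, 64, 32))   # initial carry state
    outputs = tf.keras.layers.ConvLSTM2D(32, 3, padding='same')(
        inputs, initial_state=[h0, c0])
    ```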

  Note on passing external constants to RNNs:
    You can pass "external" constants to the cell using the `constants`
    keyword argument of the `RNN.__call__` (as well as `RNN.call`) method.
    This requires that the `cell.call` method accepts the same keyword
    argument `constants`. Such constants can be used to condition the cell
    transformation on additional static inputs (not changing over time),
    a.k.a. an attention mechanism.
  """

  def __init__(self,
               cell,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               unroll=False,
               **kwargs):
    if unroll:
      raise TypeError('Unrolling isn\'t possible with '
                      'convolutional RNNs.')
    if isinstance(cell, (list, tuple)):
      # The StackedConvRNN2DCells isn't implemented yet.
      raise TypeError('It is not possible at the moment to '
                      'stack convolutional cells.')
    super(ConvRNN2D, self).__init__(cell,
                                    return_sequences,
                                    return_state,
                                    go_backwards,
                                    stateful,
                                    unroll,
                                    **kwargs)
    self.input_spec = [InputSpec(ndim=5)]
    self.states = None
    self._num_constants = None

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    cell = self.cell
    if cell.data_format == 'channels_first':
      rows = input_shape[3]
      cols = input_shape[4]
    elif cell.data_format == 'channels_last':
      rows = input_shape[2]
      cols = input_shape[3]
    rows = conv_utils.conv_output_length(rows,
                                         cell.kernel_size[0],
                                         padding=cell.padding,
                                         stride=cell.strides[0],
                                         dilation=cell.dilation_rate[0])
    cols = conv_utils.conv_output_length(cols,
                                         cell.kernel_size[1],
                                         padding=cell.padding,
                                         stride=cell.strides[1],
                                         dilation=cell.dilation_rate[1])

    if cell.data_format == 'channels_first':
      output_shape = input_shape[:2] + (cell.filters, rows, cols)
    elif cell.data_format == 'channels_last':
      output_shape = input_shape[:2] + (rows, cols, cell.filters)

    if not self.return_sequences:
      output_shape = output_shape[:1] + output_shape[2:]

    if self.return_state:
      output_shape = [output_shape]
      if cell.data_format == 'channels_first':
        output_shape += [(input_shape[0], cell.filters, rows, cols)
                         for _ in range(2)]
      elif cell.data_format == 'channels_last':
        output_shape += [(input_shape[0], rows, cols, cell.filters)
                         for _ in range(2)]
    return output_shape

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Note input_shape will be list of shapes of initial states and
    # constants if these are passed in __call__.
    if self._num_constants is not None:
      constants_shape = input_shape[-self._num_constants:]  # pylint: disable=E1130
    else:
      constants_shape = None

    if isinstance(input_shape, list):
      input_shape = input_shape[0]

    batch_size = input_shape[0] if self.stateful else None
    self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])

    # allow cell (if layer) to build before we set or validate state_spec
    if isinstance(self.cell, Layer):
      step_input_shape = (input_shape[0],) + input_shape[2:]
      if constants_shape is not None:
        self.cell.build([step_input_shape] + constants_shape)
      else:
        self.cell.build(step_input_shape)

    # set or validate state_spec
    if hasattr(self.cell.state_size, '__len__'):
      state_size = list(self.cell.state_size)
    else:
      state_size = [self.cell.state_size]

    if self.state_spec is not None:
      # initial_state was passed in call, check compatibility
      if self.cell.data_format == 'channels_first':
        ch_dim = 1
      elif self.cell.data_format == 'channels_last':
        ch_dim = 3
      if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
        raise ValueError(
            'An initial_state was passed that is not compatible with '
            '`cell.state_size`. Received `state_spec`={}; '
            'However `cell.state_size` is '
            '{}'.format([spec.shape for spec in self.state_spec],
                        self.cell.state_size))
    else:
      if self.cell.data_format == 'channels_first':
        self.state_spec = [InputSpec(shape=(None, dim, None, None))
                           for dim in state_size]
      elif self.cell.data_format == 'channels_last':
        self.state_spec = [InputSpec(shape=(None, None, None, dim))
                           for dim in state_size]
    if self.stateful:
      self.reset_states()
    self.built = True

  def get_initial_state(self, inputs):
    # (samples, timesteps, rows, cols, filters)
    initial_state = backend.zeros_like(inputs)
    # (samples, rows, cols, filters)
    initial_state = backend.sum(initial_state, axis=1)
    shape = list(self.cell.kernel_shape)
    shape[-1] = self.cell.filters
    # Convolve the zero state with an all-zeros kernel (shaped like the cell
    # kernel but with `filters` output channels) so the initial state takes
    # on the spatial dimensions and channel count of the cell output.
    initial_state = self.cell.input_conv(initial_state,
                                         array_ops.zeros(tuple(shape),
                                                         initial_state.dtype),
                                         padding=self.cell.padding)

    if hasattr(self.cell.state_size, '__len__'):
      return [initial_state for _ in self.cell.state_size]
    else:
      return [initial_state]

  def call(self,
           inputs,
           mask=None,
           training=None,
           initial_state=None,
           constants=None):
    # note that the .build() method of subclasses MUST define
    # self.input_spec and self.state_spec with complete input shapes.
    inputs, initial_state, constants = self._process_inputs(
        inputs, initial_state, constants)

    if isinstance(mask, list):
      mask = mask[0]
    timesteps = backend.int_shape(inputs)[1]

    kwargs = {}
    if generic_utils.has_arg(self.cell.call, 'training'):
      kwargs['training'] = training

    if constants:
      if not generic_utils.has_arg(self.cell.call, 'constants'):
        raise ValueError('RNN cell does not support constants')

      def step(inputs, states):
        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
        return self.cell.call(inputs, states, constants=constants, **kwargs)
    else:
      def step(inputs, states):
        return self.cell.call(inputs, states, **kwargs)

    last_output, outputs, states = backend.rnn(step,
                                               inputs,
                                               initial_state,
                                               constants=constants,
                                               go_backwards=self.go_backwards,
                                               mask=mask,
                                               input_length=timesteps)
    if self.stateful:
      updates = [
          backend.update(self_state, state)
          for self_state, state in zip(self.states, states)
      ]
      self.add_update(updates)

    if self.return_sequences:
      output = outputs
    else:
      output = last_output

    if self.return_state:
      if not isinstance(states, (list, tuple)):
        states = [states]
      else:
        states = list(states)
      return [output] + states
    else:
      return output

  def reset_states(self, states=None):
    if not self.stateful:
      raise AttributeError('Layer must be stateful.')
    input_shape = self.input_spec[0].shape
    state_shape = self.compute_output_shape(input_shape)
    if self.return_state:
      state_shape = state_shape[0]
    if self.return_sequences:
      state_shape = state_shape[:1].concatenate(state_shape[2:])
    if None in state_shape:
      raise ValueError('If a RNN is stateful, it needs to know '
                       'its batch size. Specify the batch size '
                       'of your input tensors: \n'
                       '- If using a Sequential model, '
                       'specify the batch size by passing '
                       'a `batch_input_shape` '
                       'argument to your first layer.\n'
                       '- If using the functional API, specify '
                       'the batch size by passing a '
                       '`batch_shape` argument to your Input layer.\n'
                       'The same thing goes for the number of rows and '
                       'columns.')

    # helper function
    def get_tuple_shape(nb_channels):
      result = list(state_shape)
      if self.cell.data_format == 'channels_first':
        result[1] = nb_channels
      elif self.cell.data_format == 'channels_last':
        result[3] = nb_channels
      else:
        raise KeyError
      return tuple(result)

    # initialize state if None
    if self.states[0] is None:
      if hasattr(self.cell.state_size, '__len__'):
        self.states = [backend.zeros(get_tuple_shape(dim))
                       for dim in self.cell.state_size]
      else:
        self.states = [backend.zeros(get_tuple_shape(self.cell.state_size))]
    elif states is None:
      if hasattr(self.cell.state_size, '__len__'):
        for state, dim in zip(self.states, self.cell.state_size):
          backend.set_value(state, np.zeros(get_tuple_shape(dim)))
      else:
        backend.set_value(self.states[0],
                          np.zeros(get_tuple_shape(self.cell.state_size)))
    else:
      if not isinstance(states, (list, tuple)):
        states = [states]
      if len(states) != len(self.states):
        raise ValueError('Layer ' + self.name + ' expects ' +
                         str(len(self.states)) + ' states, ' +
                         'but it received ' + str(len(states)) +
                         ' state values. Input received: ' + str(states))
      for index, (value, state) in enumerate(zip(states, self.states)):
        if hasattr(self.cell.state_size, '__len__'):
          dim = self.cell.state_size[index]
        else:
          dim = self.cell.state_size
        if value.shape != get_tuple_shape(dim):
          raise ValueError('State ' + str(index) +
                           ' is incompatible with layer ' +
                           self.name + ': expected shape=' +
                           str(get_tuple_shape(dim)) +
                           ', found shape=' + str(value.shape))
        # TODO(anjalisridhar): consider batch calls to `set_value`.
        backend.set_value(state, value)


class ConvLSTM2DCell(DropoutRNNCellMixin, Layer):
  """Cell class for the ConvLSTM2D layer.

  Args:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      If you don't specify anything, no activation is applied
      (i.e. "linear" activation: `a(x) = x`).
    recurrent_activation: Activation function to use
      for the recurrent step.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 4D tensor.
    states: List of state tensors corresponding to the previous timestep.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. Only relevant when `dropout` or
      `recurrent_dropout` is used.
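
  Example (a minimal, illustrative sketch of driving the cell for a single
  timestep; the batch size and image shape are arbitrary, and the default
  `channels_last` image data format is assumed):

  ```python
  import tensorflow as tf

  cell = ConvLSTM2DCell(filters=8, kernel_size=3, padding='same')
  cell.build(input_shape=(None, 32, 32, 3))  # 4D input: (batch, rows, cols, channels)

  x = tf.zeros((4, 32, 32, 3))   # one timestep of input
  h0 = tf.zeros((4, 32, 32, 8))  # previous hidden state
  c0 = tf.zeros((4, 32, 32, 8))  # previous carry state
  # 'same' padding keeps rows/cols unchanged, so the states above match the
  # cell output shape.
  output, [h1, c1] = cell.call(x, [h0, c0])
  ```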
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    super(ConvLSTM2DCell, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
                                                    'dilation_rate')
    self.activation = activations.get(activation)
    self.recurrent_activation = activations.get(recurrent_activation)
    self.use_bias = use_bias

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.recurrent_initializer = initializers.get(recurrent_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.unit_forget_bias = unit_forget_bias

    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)

    self.kernel_constraint = constraints.get(kernel_constraint)
    self.recurrent_constraint = constraints.get(recurrent_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
    self.state_size = (self.filters, self.filters)

  def build(self, input_shape):

    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = -1
    if input_shape[channel_axis] is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    input_dim = input_shape[channel_axis]
    kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
    self.kernel_shape = kernel_shape
    recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)

    self.kernel = self.add_weight(shape=kernel_shape,
                                  initializer=self.kernel_initializer,
                                  name='kernel',
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
    self.recurrent_kernel = self.add_weight(
        shape=recurrent_kernel_shape,
        initializer=self.recurrent_initializer,
        name='recurrent_kernel',
        regularizer=self.recurrent_regularizer,
        constraint=self.recurrent_constraint)

    if self.use_bias:
      if self.unit_forget_bias:

        def bias_initializer(_, *args, **kwargs):
          return backend.concatenate([
              self.bias_initializer((self.filters,), *args, **kwargs),
              initializers.get('ones')((self.filters,), *args, **kwargs),
              self.bias_initializer((self.filters * 2,), *args, **kwargs),
          ])
      else:
        bias_initializer = self.bias_initializer
      self.bias = self.add_weight(
          shape=(self.filters * 4,),
          name='bias',
          initializer=bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs, states, training=None):
    h_tm1 = states[0]  # previous memory state
    c_tm1 = states[1]  # previous carry state

    # dropout matrices for input units
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    # dropout matrices for recurrent units
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=4)

    if 0 < self.dropout < 1.:
      inputs_i = inputs * dp_mask[0]
      inputs_f = inputs * dp_mask[1]
      inputs_c = inputs * dp_mask[2]
      inputs_o = inputs * dp_mask[3]
    else:
      inputs_i = inputs
      inputs_f = inputs
      inputs_c = inputs
      inputs_o = inputs

    if 0 < self.recurrent_dropout < 1.:
      h_tm1_i = h_tm1 * rec_dp_mask[0]
      h_tm1_f = h_tm1 * rec_dp_mask[1]
      h_tm1_c = h_tm1 * rec_dp_mask[2]
      h_tm1_o = h_tm1 * rec_dp_mask[3]
    else:
      h_tm1_i = h_tm1
      h_tm1_f = h_tm1
      h_tm1_c = h_tm1
      h_tm1_o = h_tm1

    # The kernels hold the weights of all four gates; split them along the
    # output-channel axis into per-gate kernels.
    (kernel_i, kernel_f,
     kernel_c, kernel_o) = array_ops.split(self.kernel, 4, axis=3)
    (recurrent_kernel_i,
     recurrent_kernel_f,
     recurrent_kernel_c,
     recurrent_kernel_o) = array_ops.split(self.recurrent_kernel, 4, axis=3)

    if self.use_bias:
      bias_i, bias_f, bias_c, bias_o = array_ops.split(self.bias, 4)
    else:
      bias_i, bias_f, bias_c, bias_o = None, None, None, None

    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    i = self.recurrent_activation(x_i + h_i)  # input gate
    f = self.recurrent_activation(x_f + h_f)  # forget gate
    c = f * c_tm1 + i * self.activation(x_c + h_c)  # new carry (cell) state
    o = self.recurrent_activation(x_o + h_o)  # output gate
    h = o * self.activation(c)  # new hidden state
    return h, [h, c]

  def input_conv(self, x, w, b=None, padding='valid'):
    conv_out = backend.conv2d(x, w, strides=self.strides,
                              padding=padding,
                              data_format=self.data_format,
                              dilation_rate=self.dilation_rate)
    if b is not None:
      conv_out = backend.bias_add(conv_out, b,
                                  data_format=self.data_format)
    return conv_out

  def recurrent_conv(self, x, w):
    conv_out = backend.conv2d(x, w, strides=(1, 1),
                              padding='same',
                              data_format=self.data_format)
    return conv_out

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2DCell, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.ConvLSTM2D')
class ConvLSTM2D(ConvRNN2D):
  """2D Convolutional LSTM layer.

  A convolutional LSTM is similar to an LSTM, but the input transformations
  and recurrent transformations are both convolutional. This layer is typically
  used to process timeseries of images (i.e. video-like data).

  It is known to perform well for weather data forecasting,
  using inputs that are timeseries of 2D grids of sensor values.
  It isn't usually applied to regular video data, due to its high computational
  cost.

  Args:
    filters: Integer, the dimensionality of the output space
      (i.e. the number of output filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      dimensions of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the strides of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, time, ..., channels)`
      while `channels_first` corresponds to
      inputs with shape `(batch, time, channels, ...)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function to use.
      By default the hyperbolic tangent activation function
      is applied (`tanh(x)`).
    recurrent_activation: Activation function to use
      for the recurrent step.
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix,
      used for the linear transformation of the inputs.
    recurrent_initializer: Initializer for the `recurrent_kernel`
      weights matrix,
      used for the linear transformation of the recurrent state.
    bias_initializer: Initializer for the bias vector.
    unit_forget_bias: Boolean.
      If True, add 1 to the bias of the forget gate at initialization.
      Use in combination with `bias_initializer="zeros"`.
      This is recommended in [Jozefowicz et al., 2015](
        http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
    kernel_regularizer: Regularizer function applied to
      the `kernel` weights matrix.
    recurrent_regularizer: Regularizer function applied to
      the `recurrent_kernel` weights matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to
      the `kernel` weights matrix.
    recurrent_constraint: Constraint function applied to
      the `recurrent_kernel` weights matrix.
    bias_constraint: Constraint function applied to the bias vector.
    return_sequences: Boolean. Whether to return the last output
      in the output sequence, or the full sequence. (default False)
    return_state: Boolean. Whether to return the last state
      in addition to the output. (default False)
    go_backwards: Boolean (default False).
      If True, process the input sequence backwards.
    stateful: Boolean (default False). If True, the last state
      for each sample at index i in a batch will be used as initial
      state for the sample of index i in the following batch.
    dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the inputs.
    recurrent_dropout: Float between 0 and 1.
      Fraction of the units to drop for
      the linear transformation of the recurrent state.

  Call arguments:
    inputs: A 5D float tensor (see input shape description below).
    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
      a given timestep should be masked.
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode. This argument is passed to the cell
      when calling it. This is only relevant if `dropout` or `recurrent_dropout`
      are set.
    initial_state: List of initial state tensors to be passed to the first
      call of the cell.

  Input shape:
    - If data_format='channels_first'
        5D tensor with shape:
        `(samples, time, channels, rows, cols)`
    - If data_format='channels_last'
        5D tensor with shape:
        `(samples, time, rows, cols, channels)`

  Output shape:
    - If `return_state`: a list of tensors. The first tensor is
      the output. The remaining tensors are the last states,
      each 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
    - If `return_sequences`: 5D tensor with shape:
      `(samples, timesteps, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 5D tensor with shape:
      `(samples, timesteps, new_rows, new_cols, filters)`
      if data_format='channels_last'.
    - Else, 4D tensor with shape:
      `(samples, filters, new_rows, new_cols)`
      if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)`
      if data_format='channels_last'.

  Raises:
    ValueError: in case of invalid constructor arguments.

  References:
    - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)
    (the current implementation does not include the feedback loop on the
    cell's output).

  Example:

  ```python
  steps = 10
  height = 32
  width = 32
  input_channels = 3
  output_channels = 6

  inputs = tf.keras.Input(shape=(steps, height, width, input_channels))
  layer = tf.keras.layers.ConvLSTM2D(filters=output_channels, kernel_size=3)
  outputs = layer(inputs)
  ```
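
  A further illustrative sketch: stacking two `ConvLSTM2D` layers requires
  `return_sequences=True` on every layer except the last, so that each layer
  still receives a 5D sequence (the filter counts below are arbitrary):

  ```python
  x = tf.keras.Input(shape=(10, 32, 32, 3))
  h = tf.keras.layers.ConvLSTM2D(16, 3, padding='same',
                                 return_sequences=True)(x)
  y = tf.keras.layers.ConvLSTM2D(8, 3, padding='same')(h)
  model = tf.keras.Model(x, y)
  ```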
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation='tanh',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               return_sequences=False,
               return_state=False,
               go_backwards=False,
               stateful=False,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
    cell = ConvLSTM2DCell(filters=filters,
                          kernel_size=kernel_size,
                          strides=strides,
                          padding=padding,
                          data_format=data_format,
                          dilation_rate=dilation_rate,
                          activation=activation,
                          recurrent_activation=recurrent_activation,
                          use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          recurrent_initializer=recurrent_initializer,
                          bias_initializer=bias_initializer,
                          unit_forget_bias=unit_forget_bias,
                          kernel_regularizer=kernel_regularizer,
                          recurrent_regularizer=recurrent_regularizer,
                          bias_regularizer=bias_regularizer,
                          kernel_constraint=kernel_constraint,
                          recurrent_constraint=recurrent_constraint,
                          bias_constraint=bias_constraint,
                          dropout=dropout,
                          recurrent_dropout=recurrent_dropout,
                          dtype=kwargs.get('dtype'))
    super(ConvLSTM2D, self).__init__(cell,
                                     return_sequences=return_sequences,
                                     return_state=return_state,
                                     go_backwards=go_backwards,
                                     stateful=stateful,
                                     **kwargs)
    self.activity_regularizer = regularizers.get(activity_regularizer)

  def call(self, inputs, mask=None, training=None, initial_state=None):
    return super(ConvLSTM2D, self).call(inputs,
                                        mask=mask,
                                        training=training,
                                        initial_state=initial_state)

  @property
  def filters(self):
    return self.cell.filters

  @property
  def kernel_size(self):
    return self.cell.kernel_size

  @property
  def strides(self):
    return self.cell.strides

  @property
  def padding(self):
    return self.cell.padding

  @property
  def data_format(self):
    return self.cell.data_format

  @property
  def dilation_rate(self):
    return self.cell.dilation_rate

  @property
  def activation(self):
    return self.cell.activation

  @property
  def recurrent_activation(self):
    return self.cell.recurrent_activation

  @property
  def use_bias(self):
    return self.cell.use_bias

  @property
  def kernel_initializer(self):
    return self.cell.kernel_initializer

  @property
  def recurrent_initializer(self):
    return self.cell.recurrent_initializer

  @property
  def bias_initializer(self):
    return self.cell.bias_initializer

  @property
  def unit_forget_bias(self):
    return self.cell.unit_forget_bias

  @property
  def kernel_regularizer(self):
    return self.cell.kernel_regularizer

  @property
  def recurrent_regularizer(self):
    return self.cell.recurrent_regularizer

  @property
  def bias_regularizer(self):
    return self.cell.bias_regularizer

  @property
  def kernel_constraint(self):
    return self.cell.kernel_constraint

  @property
  def recurrent_constraint(self):
    return self.cell.recurrent_constraint

  @property
  def bias_constraint(self):
    return self.cell.bias_constraint

  @property
  def dropout(self):
    return self.cell.dropout

  @property
  def recurrent_dropout(self):
    return self.cell.recurrent_dropout

  def get_config(self):
    config = {'filters': self.filters,
              'kernel_size': self.kernel_size,
              'strides': self.strides,
              'padding': self.padding,
              'data_format': self.data_format,
              'dilation_rate': self.dilation_rate,
              'activation': activations.serialize(self.activation),
              'recurrent_activation': activations.serialize(
                  self.recurrent_activation),
              'use_bias': self.use_bias,
              'kernel_initializer': initializers.serialize(
                  self.kernel_initializer),
              'recurrent_initializer': initializers.serialize(
                  self.recurrent_initializer),
              'bias_initializer': initializers.serialize(self.bias_initializer),
              'unit_forget_bias': self.unit_forget_bias,
              'kernel_regularizer': regularizers.serialize(
                  self.kernel_regularizer),
              'recurrent_regularizer': regularizers.serialize(
                  self.recurrent_regularizer),
              'bias_regularizer': regularizers.serialize(self.bias_regularizer),
              'activity_regularizer': regularizers.serialize(
                  self.activity_regularizer),
              'kernel_constraint': constraints.serialize(
                  self.kernel_constraint),
              'recurrent_constraint': constraints.serialize(
                  self.recurrent_constraint),
              'bias_constraint': constraints.serialize(self.bias_constraint),
              'dropout': self.dropout,
              'recurrent_dropout': self.recurrent_dropout}
    base_config = super(ConvLSTM2D, self).get_config()
    del base_config['cell']
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    return cls(**config)