1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Recurrent layers and their base classes.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23
24import numpy as np
25
26from tensorflow.python.eager import context
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.keras import activations
29from tensorflow.python.keras import backend as K
30from tensorflow.python.keras import constraints
31from tensorflow.python.keras import initializers
32from tensorflow.python.keras import regularizers
33from tensorflow.python.keras.engine.base_layer import Layer
34from tensorflow.python.keras.engine.input_spec import InputSpec
35from tensorflow.python.keras.utils import generic_utils
36from tensorflow.python.keras.utils import tf_utils
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import state_ops
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.training.tracking import base as trackable
41from tensorflow.python.util import nest
42from tensorflow.python.util.tf_export import keras_export
43
44
45@keras_export('keras.layers.StackedRNNCells')
46class StackedRNNCells(Layer):
47  """Wrapper allowing a stack of RNN cells to behave as a single cell.
48
49  Used to implement efficient stacked RNNs.
50
51  Arguments:
52    cells: List of RNN cell instances.
53
54  Examples:
55
56  ```python
57  cells = [
58      keras.layers.LSTMCell(output_dim),
59      keras.layers.LSTMCell(output_dim),
60      keras.layers.LSTMCell(output_dim),
61  ]
62
63  inputs = keras.Input((timesteps, input_dim))
64  x = keras.layers.RNN(cells)(inputs)
65  ```
66  """
67
68  def __init__(self, cells, **kwargs):
69    for cell in cells:
70      if not hasattr(cell, 'call'):
71        raise ValueError('All cells must have a `call` method. '
72                         'Received cells:', cells)
73      if not hasattr(cell, 'state_size'):
74        raise ValueError('All cells must have a '
75                         '`state_size` attribute. '
76                         'Received cells:', cells)
77    self.cells = cells
78    # reverse_state_order determines whether the state sizes are returned in the
79    # reverse order of the cells' states. Users might want to set this to True to
80    # keep the existing behavior. This is only useful with RNN(return_state=True),
81    # since the states are returned in the same order as state_size.
82    self.reverse_state_order = kwargs.pop('reverse_state_order', False)
83    if self.reverse_state_order:
84      logging.warning('reverse_state_order=True in StackedRNNCells will soon '
85                      'be deprecated. Please update the code to work with the '
86                      'natural order of states if you rely on the RNN states, '
87                      'e.g. RNN(return_state=True).')
88    super(StackedRNNCells, self).__init__(**kwargs)
89
90  @property
91  def state_size(self):
92    return tuple(c.state_size for c in
93                 (self.cells[::-1] if self.reverse_state_order else self.cells))
94
95  @property
96  def output_size(self):
97    if getattr(self.cells[-1], 'output_size', None) is not None:
98      return self.cells[-1].output_size
99    elif _is_multiple_state(self.cells[-1].state_size):
100      return self.cells[-1].state_size[0]
101    else:
102      return self.cells[-1].state_size
103
104  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
105    initial_states = []
106    for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
107      get_initial_state_fn = getattr(cell, 'get_initial_state', None)
108      if get_initial_state_fn:
109        initial_states.append(get_initial_state_fn(
110            inputs=inputs, batch_size=batch_size, dtype=dtype))
111      else:
112        initial_states.append(_generate_zero_filled_state_for_cell(
113            cell, inputs, batch_size, dtype))
114
115    return tuple(initial_states)
116
117  def call(self, inputs, states, constants=None, **kwargs):
118    # Recover per-cell states.
119    state_size = (self.state_size[::-1]
120                  if self.reverse_state_order else self.state_size)
121    nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))
122
123    # Call the cells in order and store the returned states.
124    new_nested_states = []
125    for cell, states in zip(self.cells, nested_states):
126      states = states if nest.is_sequence(states) else [states]
127      # TF cells do not wrap the state into a list when there is only one state.
128      is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
129      states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
130      if generic_utils.has_arg(cell.call, 'constants'):
131        inputs, states = cell.call(inputs, states, constants=constants,
132                                   **kwargs)
133      else:
134        inputs, states = cell.call(inputs, states, **kwargs)
135      new_nested_states.append(states)
136
137    return inputs, nest.pack_sequence_as(state_size,
138                                         nest.flatten(new_nested_states))
139
140  @tf_utils.shape_type_conversion
141  def build(self, input_shape):
142    if isinstance(input_shape, list):
143      constants_shape = input_shape[1:]
144      input_shape = input_shape[0]
145    for cell in self.cells:
146      if isinstance(cell, Layer):
147        if generic_utils.has_arg(cell.call, 'constants'):
148          cell.build([input_shape] + constants_shape)
149        else:
150          cell.build(input_shape)
151      if getattr(cell, 'output_size', None) is not None:
152        output_dim = cell.output_size
153      elif _is_multiple_state(cell.state_size):
154        output_dim = cell.state_size[0]
155      else:
156        output_dim = cell.state_size
157      input_shape = tuple([input_shape[0]] +
158                          tensor_shape.as_shape(output_dim).as_list())
159    self.built = True
160
161  def get_config(self):
162    cells = []
163    for cell in self.cells:
164      cells.append({
165          'class_name': cell.__class__.__name__,
166          'config': cell.get_config()
167      })
168    config = {'cells': cells}
169    base_config = super(StackedRNNCells, self).get_config()
170    return dict(list(base_config.items()) + list(config.items()))
171
172  @classmethod
173  def from_config(cls, config, custom_objects=None):
174    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
175    cells = []
176    for cell_config in config.pop('cells'):
177      cells.append(
178          deserialize_layer(cell_config, custom_objects=custom_objects))
179    return cls(cells, **config)
180
181
182@keras_export('keras.layers.RNN')
183class RNN(Layer):
184  """Base class for recurrent layers.
185
186  Arguments:
187    cell: An RNN cell instance or a list of RNN cell instances.
188      An RNN cell is a class that has:
189      - A `call(input_at_t, states_at_t)` method, returning
190        `(output_at_t, states_at_t_plus_1)`. The call method of the
191        cell can also take the optional argument `constants`, see
192        section "Note on passing external constants" below.
193      - A `state_size` attribute. This can be a single integer
194        (single state) in which case it is the size of the recurrent
195        state. This can also be a list/tuple of integers (one size per
196        state).
197        The `state_size` can also be a TensorShape or a tuple/list of
198        TensorShapes, to represent a high dimensional state.
199      - An `output_size` attribute. This can be a single integer or a
200        TensorShape, which represents the shape of the output. For backward
201        compatibility, if this attribute is not available for the
202        cell, the value will be inferred from the first element of the
203        `state_size`.
204      - A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
205        method that creates a tensor meant to be fed to `call()` as the
206        initial state, if the user didn't specify any initial state via other
207        means. The returned initial state should have the shape
208        [batch, cell.state_size]. The cell might choose to create a zero-filled
209        tensor, or a tensor with other values, depending on its implementation.
210        `inputs` is the input tensor to the RNN layer, which should
211        contain the batch size as its shape[0], and also the dtype. Note that
212        shape[0] might be None during graph construction. Either
213        the `inputs` or the pair of `batch_size` and `dtype` is provided.
214        `batch_size` is a scalar tensor that represents the batch size
215        of the input. `dtype` is a `tf.DType` that represents the dtype of
216        the input.
217        For backward compatibility, if this method is not implemented
218        by the cell, the RNN layer will create a zero-filled tensor with
219        the shape of [batch, cell.state_size].
220      In the case that `cell` is a list of RNN cell instances, the cells
221      will be stacked one after the other in the RNN, implementing an
222      efficient stacked RNN.
223    return_sequences: Boolean. Whether to return the last output
224      in the output sequence, or the full sequence.
225    return_state: Boolean. Whether to return the last state
226      in addition to the output.
227    go_backwards: Boolean (default False).
228      If True, process the input sequence backwards and return the
229      reversed sequence.
230    stateful: Boolean (default False). If True, the last state
231      for each sample at index i in a batch will be used as initial
232      state for the sample of index i in the following batch.
233    unroll: Boolean (default False).
234      If True, the network will be unrolled, else a symbolic loop will be used.
235      Unrolling can speed-up a RNN,
236      although it tends to be more memory-intensive.
237      Unrolling is only suitable for short sequences.
238    time_major: The shape format of the `inputs` and `outputs` tensors.
239        If True, the inputs and outputs will be in shape
240        `(timesteps, batch, ...)`, whereas in the False case, it will be
241        `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
242        efficient because it avoids transposes at the beginning and end of the
243        RNN calculation. However, most TensorFlow data is batch-major, so by
244        default this function accepts input and emits output in batch-major
245        form.
246
247  Call arguments:
248    inputs: Input tensor.
249    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
250      a given timestep should be masked.
251    training: Python boolean indicating whether the layer should behave in
252      training mode or in inference mode. This argument is passed to the cell
253      when calling it. This is for use with cells that use dropout.
254    initial_state: List of initial state tensors to be passed to the first
255      call of the cell.
256    constants: List of constant tensors to be passed to the cell at each
257      timestep.
258
259  Input shape:
260    N-D tensor with shape `(batch_size, timesteps, ...)` or
261    `(timesteps, batch_size, ...)` when time_major is True.
262
263  Output shape:
264    - If `return_state`: a list of tensors. The first tensor is
265      the output. The remaining tensors are the last states,
266      each with shape `(batch_size, state_size)`, where `state_size` could
267      be a high dimension tensor shape.
268    - If `return_sequences`: N-D tensor with shape
269      `(batch_size, timesteps, output_size)`, where `output_size` could
270      be a high dimension tensor shape, or
271      `(timesteps, batch_size, output_size)` when `time_major` is True.
272    - Else, N-D tensor with shape `(batch_size, output_size)`, where
273      `output_size` could be a high dimension tensor shape.
274
275  Masking:
276    This layer supports masking for input data with a variable number
277    of timesteps. To introduce masks to your data,
278    use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
279    set to `True`.
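
    For example (a sketch; the vocabulary size and dimensions below are
    illustrative):

    ```python
    model = keras.Sequential([
        keras.layers.Embedding(input_dim=1000, output_dim=64, mask_zero=True),
        keras.layers.RNN(keras.layers.LSTMCell(32)),
    ])
    ```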
280
281  Note on using statefulness in RNNs:
282    You can set RNN layers to be 'stateful', which means that the states
283    computed for the samples in one batch will be reused as initial states
284    for the samples in the next batch. This assumes a one-to-one mapping
285    between samples in different successive batches.
286
287    To enable statefulness:
288      - Specify `stateful=True` in the layer constructor.
289      - Specify a fixed batch size for your model, by passing
290        - if it is a Sequential model:
291          `batch_input_shape=(...)` to the first layer in your model.
292        - else, for a functional model with 1 or more Input layers:
293          `batch_shape=(...)` to all the first layers in your model.
294        This is the expected shape of your inputs
295        *including the batch size*.
296        It should be a tuple of integers, e.g. `(32, 10, 100)`.
297      - Specify `shuffle=False` when calling fit().
298
299    To reset the states of your model, call `.reset_states()` on either
300    a specific layer, or on your entire model.
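
    For example (a sketch; the cell, shapes and training settings are
    illustrative):

    ```python
    model = keras.Sequential([
        keras.layers.RNN(keras.layers.LSTMCell(32),
                         stateful=True,
                         batch_input_shape=(32, 10, 100)),
        keras.layers.Dense(1),
    ])
    # model.fit(x, y, batch_size=32, shuffle=False)  # x.shape == (32, 10, 100)
    model.reset_states()
    ```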
301
302  Note on specifying the initial state of RNNs:
303    You can specify the initial state of RNN layers symbolically by
304    calling them with the keyword argument `initial_state`. The value of
305    `initial_state` should be a tensor or list of tensors representing
306    the initial state of the RNN layer.
307
308    You can specify the initial state of RNN layers numerically by
309    calling `reset_states` with the keyword argument `states`. The value of
310    `states` should be a numpy array or list of numpy arrays representing
311    the initial state of the RNN layer.
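
    For example (a sketch; `timesteps`, `input_dim` and `units` are
    placeholders):

    ```python
    inputs = keras.Input((timesteps, input_dim))
    # Two state tensors, since an LSTMCell carries two states per sample.
    initial_state = [keras.Input((units,)), keras.Input((units,))]
    output = keras.layers.RNN(keras.layers.LSTMCell(units))(
        inputs, initial_state=initial_state)
    ```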
312
313  Note on passing external constants to RNNs:
314    You can pass "external" constants to the cell using the `constants`
315    keyword argument of the `RNN.__call__` (as well as `RNN.call`) method. This
316    requires that the `cell.call` method accepts the same keyword argument
317    `constants`. Such constants can be used to condition the cell
318    transformation on additional static inputs (not changing over time),
319    a.k.a. an attention mechanism.
320
321  Examples:
322
323  ```python
324  # First, let's define an RNN cell, as a layer subclass.
325
326  class MinimalRNNCell(keras.layers.Layer):
327
328      def __init__(self, units, **kwargs):
329          self.units = units
330          self.state_size = units
331          super(MinimalRNNCell, self).__init__(**kwargs)
332
333      def build(self, input_shape):
334          self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
335                                        initializer='uniform',
336                                        name='kernel')
337          self.recurrent_kernel = self.add_weight(
338              shape=(self.units, self.units),
339              initializer='uniform',
340              name='recurrent_kernel')
341          self.built = True
342
343      def call(self, inputs, states):
344          prev_output = states[0]
345          h = K.dot(inputs, self.kernel)
346          output = h + K.dot(prev_output, self.recurrent_kernel)
347          return output, [output]
348
349  # Let's use this cell in an RNN layer:
350
351  cell = MinimalRNNCell(32)
352  x = keras.Input((None, 5))
353  layer = RNN(cell)
354  y = layer(x)
355
356  # Here's how to use the cell to build a stacked RNN:
357
358  cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
359  x = keras.Input((None, 5))
360  layer = RNN(cells)
361  y = layer(x)
362  ```
363  """
364
365  def __init__(self,
366               cell,
367               return_sequences=False,
368               return_state=False,
369               go_backwards=False,
370               stateful=False,
371               unroll=False,
372               time_major=False,
373               **kwargs):
374    if isinstance(cell, (list, tuple)):
375      cell = StackedRNNCells(cell)
376    if not hasattr(cell, 'call'):
377      raise ValueError('`cell` should have a `call` method. '
378                       'The RNN was passed:', cell)
379    if not hasattr(cell, 'state_size'):
380      raise ValueError('The RNN cell should have '
381                       'an attribute `state_size` '
382                       '(tuple of integers, '
383                       'one integer per RNN state).')
384    # If True, the output for masked timesteps will be zeros, whereas in the
385    # False case, the output from the previous timestep is returned for them.
386    self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
387
388    if 'input_shape' not in kwargs and (
389        'input_dim' in kwargs or 'input_length' in kwargs):
390      input_shape = (kwargs.pop('input_length', None),
391                     kwargs.pop('input_dim', None))
392      kwargs['input_shape'] = input_shape
393
394    super(RNN, self).__init__(**kwargs)
395    self.cell = cell
396    self.return_sequences = return_sequences
397    self.return_state = return_state
398    self.go_backwards = go_backwards
399    self.stateful = stateful
400    self.unroll = unroll
401    self.time_major = time_major
402
403    self.supports_masking = True
404    # The input shape is not known yet; it could have nested tensor inputs, and
405    # the input spec will be the list of specs for the flattened inputs.
406    self.input_spec = None
407    self.state_spec = None
408    self._states = None
409    self.constants_spec = None
410    self._num_constants = None
411    self._num_inputs = None
412
413  @property
414  def states(self):
415    if self._states is None:
416      state = nest.map_structure(lambda _: None, self.cell.state_size)
417      return state if nest.is_sequence(self.cell.state_size) else [state]
418    return self._states
419
420  @states.setter
421  # Automatic tracking catches "self._states" which adds an extra weight and
422  # breaks HDF5 checkpoints.
423  @trackable.no_automatic_dependency_tracking
424  def states(self, states):
425    self._states = states
426
427  def compute_output_shape(self, input_shape):
428    if isinstance(input_shape, list):
429      input_shape = input_shape[0]
430    # Check whether the input shape contains any nested shapes. It could be
431    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
432    # inputs.
433    try:
434      input_shape = tensor_shape.as_shape(input_shape)
435    except (ValueError, TypeError):
436      # A nested tensor input
437      input_shape = nest.flatten(input_shape)[0]
438
439    batch = input_shape[0]
440    time_step = input_shape[1]
441    if self.time_major:
442      batch, time_step = time_step, batch
443
444    if _is_multiple_state(self.cell.state_size):
445      state_size = self.cell.state_size
446    else:
447      state_size = [self.cell.state_size]
448
449    def _get_output_shape(flat_output_size):
450      output_dim = tensor_shape.as_shape(flat_output_size).as_list()
451      if self.return_sequences:
452        if self.time_major:
453          output_shape = tensor_shape.as_shape([time_step, batch] + output_dim)
454        else:
455          output_shape = tensor_shape.as_shape([batch, time_step] + output_dim)
456      else:
457        output_shape = tensor_shape.as_shape([batch] + output_dim)
458      return output_shape
459
460    if getattr(self.cell, 'output_size', None) is not None:
461      # cell.output_size could be nested structure.
462      output_shape = nest.flatten(nest.map_structure(
463          _get_output_shape, self.cell.output_size))
464      output_shape = output_shape[0] if len(output_shape) == 1 else output_shape
465    else:
466      # Note that state_size[0] could be a tensor_shape or int.
467      output_shape = _get_output_shape(state_size[0])
468
469    if self.return_state:
470      def _get_state_shape(flat_state):
471        state_shape = [batch] + tensor_shape.as_shape(flat_state).as_list()
472        return tensor_shape.as_shape(state_shape)
473      state_shape = nest.map_structure(_get_state_shape, state_size)
474      return generic_utils.to_list(output_shape) + nest.flatten(state_shape)
475    else:
476      return output_shape
477
478  def compute_mask(self, inputs, mask):
479    # Time step masks must be the same for each input.
480    # This is because the mask for an RNN is of size [batch, time_steps, 1],
481    # and specifies which time steps should be skipped, and a time step
482    # must be skipped for all inputs.
483    # TODO(scottzhu): Should we accept multiple different masks?
484    mask = nest.flatten(mask)[0]
485    output_mask = mask if self.return_sequences else None
486    if self.return_state:
487      state_mask = [None for _ in self.states]
488      return [output_mask] + state_mask
489    else:
490      return output_mask
491
492  def build(self, input_shape):
493    # Note that input_shape will be a list of shapes of the initial states and
494    # constants if these are passed in __call__.
495    if self._num_constants is not None:
496      constants_shape = input_shape[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
497      constants_shape = nest.map_structure(
498          lambda s: tuple(tensor_shape.TensorShape(s).as_list()),
499          constants_shape)
500    else:
501      constants_shape = None
502
503    if isinstance(input_shape, list):
504      input_shape = input_shape[0]
505      # The input_shape here could be a nested structure.
506
507    # Convert the TensorShapes to plain shapes here. The input could be a single
508    # tensor, or a nested structure of tensors.
509    def get_input_spec(shape):
510      if isinstance(shape, tensor_shape.TensorShape):
511        input_spec_shape = shape.as_list()
512      else:
513        input_spec_shape = list(shape)
514      batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
515      if not self.stateful:
516        input_spec_shape[batch_index] = None
517      input_spec_shape[time_step_index] = None
518      return InputSpec(shape=tuple(input_spec_shape))
519
520    def get_step_input_shape(shape):
521      if isinstance(shape, tensor_shape.TensorShape):
522        shape = tuple(shape.as_list())
523      # remove the timestep from the input_shape
524      return shape[1:] if self.time_major else (shape[0],) + shape[2:]
525
526    # Check whether the input shape contains any nested shapes. It could be
527    # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
528    # inputs.
529    try:
530      input_shape = tensor_shape.as_shape(input_shape)
531    except (ValueError, TypeError):
532      # A nested tensor input
533      pass
534
535    if not nest.is_sequence(input_shape):
536      # This indicates that there is only one input.
537      if self.input_spec is not None:
538        self.input_spec[0] = get_input_spec(input_shape)
539      else:
540        self.input_spec = [get_input_spec(input_shape)]
541      step_input_shape = get_step_input_shape(input_shape)
542    else:
543      flat_input_shapes = nest.flatten(input_shape)
544      flat_input_shapes = nest.map_structure(get_input_spec, flat_input_shapes)
545      assert len(flat_input_shapes) == self._num_inputs
546      if self.input_spec is not None:
547        self.input_spec[:self._num_inputs] = flat_input_shapes
548      else:
549        self.input_spec = flat_input_shapes
550      step_input_shape = nest.map_structure(get_step_input_shape, input_shape)
551
552    # allow cell (if layer) to build before we set or validate state_spec
553    if isinstance(self.cell, Layer):
554      if constants_shape is not None:
555        self.cell.build([step_input_shape] + constants_shape)
556      else:
557        self.cell.build(step_input_shape)
558
559    # set or validate state_spec
560    if _is_multiple_state(self.cell.state_size):
561      state_size = list(self.cell.state_size)
562    else:
563      state_size = [self.cell.state_size]
564
565    if self.state_spec is not None:
566      # initial_state was passed in call, check compatibility
567      self._validate_state_spec(state_size, self.state_spec)
568    else:
569      self.state_spec = [
570          InputSpec(shape=[None] + tensor_shape.as_shape(dim).as_list())
571          for dim in state_size
572      ]
573    if self.stateful:
574      self.reset_states()
575    self.built = True
576
577  @staticmethod
578  def _validate_state_spec(cell_state_sizes, init_state_specs):
579    """Validate the state spec between the initial_state and the state_size.
580
581    Args:
582      cell_state_sizes: list, the `state_size` attribute from the cell.
583      init_state_specs: list, the `state_spec` from the initial_state that is
584        passed in `call()`.
585
586    Raises:
587      ValueError: When initial state spec is not compatible with the state size.
588    """
589    validation_error = ValueError(
590        'An `initial_state` was passed that is not compatible with '
591        '`cell.state_size`. Received `state_spec`={}; '
592        'however `cell.state_size` is '
593        '{}'.format(init_state_specs, cell_state_sizes))
594    if len(cell_state_sizes) == len(init_state_specs):
595      for i in range(len(cell_state_sizes)):
596        if not tensor_shape.TensorShape(
597            # Ignore the first axis for init_state which is for batch
598            init_state_specs[i].shape[1:]).is_compatible_with(
599                tensor_shape.TensorShape(cell_state_sizes[i])):
600          raise validation_error
601    else:
602      raise validation_error
603
604  def get_initial_state(self, inputs):
605    get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
606
607    if nest.is_sequence(inputs):
608      # The input is a nested sequence. Use its first element to get the
609      # batch size and dtype.
610      inputs = nest.flatten(inputs)[0]
611
612    input_shape = array_ops.shape(inputs)
613    batch_size = input_shape[1] if self.time_major else input_shape[0]
614    dtype = inputs.dtype
615    if get_initial_state_fn:
616      init_state = get_initial_state_fn(
617          inputs=None, batch_size=batch_size, dtype=dtype)
618    else:
619      init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
620                                               dtype)
621    # Keras RNN expects the states in a list, even if it's a single state tensor.
622    if not nest.is_sequence(init_state):
623      init_state = [init_state]
624    # Force the state to be a list in case it is a namedtuple, e.g. LSTMStateTuple.
625    return list(init_state)
626
627  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
628    inputs, initial_state, constants = _standardize_args(inputs,
629                                                         initial_state,
630                                                         constants,
631                                                         self._num_constants,
632                                                         self._num_inputs)
633    # In case the real inputs are a nested structure, record the size of the
634    # flattened input so that we can distinguish between the real inputs,
635    # initial_state and constants.
636    self._num_inputs = len(nest.flatten(inputs))
637
638    if initial_state is None and constants is None:
639      return super(RNN, self).__call__(inputs, **kwargs)
640
641    # If any of `initial_state` or `constants` are specified and are Keras
642    # tensors, then add them to the inputs and temporarily modify the
643    # input_spec to include them.
644
645    additional_inputs = []
646    additional_specs = []
647    if initial_state is not None:
648      additional_inputs += initial_state
649      self.state_spec = [
650          InputSpec(shape=K.int_shape(state)) for state in initial_state
651      ]
652      additional_specs += self.state_spec
653    if constants is not None:
654      additional_inputs += constants
655      self.constants_spec = [
656          InputSpec(shape=K.int_shape(constant)) for constant in constants
657      ]
658      self._num_constants = len(constants)
659      additional_specs += self.constants_spec
660    # at this point additional_inputs cannot be empty
661    is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
662    for tensor in additional_inputs:
663      if K.is_keras_tensor(tensor) != is_keras_tensor:
664        raise ValueError('The initial state or constants of an RNN'
665                         ' layer cannot be specified with a mix of'
666                         ' Keras tensors and non-Keras tensors'
667                         ' (a "Keras tensor" is a tensor that was'
668                         ' returned by a Keras layer, or by `Input`)')
669
670    if is_keras_tensor:
671      # Compute the full input spec, including state and constants
672      full_input = [inputs] + additional_inputs
673      # The original input_spec is None since there could be a nested tensor
674      # input. Update the input_spec to match the inputs.
675      full_input_spec = [None for _ in range(len(nest.flatten(inputs)))
676                        ] + additional_specs
677      # Perform the call with temporarily replaced input_spec
678      self.input_spec = full_input_spec
679      output = super(RNN, self).__call__(full_input, **kwargs)
680      # Remove the additional_specs from the input spec and keep the rest. It is
681      # important to keep them since the input spec was populated by build(), and
682      # will be reused when stateful=True.
683      self.input_spec = self.input_spec[:-len(additional_specs)]
684      return output
685    else:
686      if initial_state is not None:
687        kwargs['initial_state'] = initial_state
688      if constants is not None:
689        kwargs['constants'] = constants
690      return super(RNN, self).__call__(inputs, **kwargs)
691
692  def call(self,
693           inputs,
694           mask=None,
695           training=None,
696           initial_state=None,
697           constants=None):
698    inputs, initial_state, constants = self._process_inputs(
699        inputs, initial_state, constants)
700
701    if mask is not None:
702      # Time step masks must be the same for each input.
703      # TODO(scottzhu): Should we accept multiple different masks?
704      mask = nest.flatten(mask)[0]
705
706    if nest.is_sequence(inputs):
707      # In the case of nested input, use the first element for shape check.
708      input_shape = K.int_shape(nest.flatten(inputs)[0])
709    else:
710      input_shape = K.int_shape(inputs)
711    timesteps = input_shape[0] if self.time_major else input_shape[1]
712    if self.unroll and timesteps is None:
713      raise ValueError('Cannot unroll a RNN if the '
714                       'time dimension is undefined. \n'
715                       '- If using a Sequential model, '
716                       'specify the time dimension by passing '
717                       'an `input_shape` or `batch_input_shape` '
718                       'argument to your first layer. If your '
719                       'first layer is an Embedding, you can '
720                       'also use the `input_length` argument.\n'
721                       '- If using the functional API, specify '
722                       'the time dimension by passing a `shape` '
723                       'or `batch_shape` argument to your Input layer.')
724
725    kwargs = {}
726    if generic_utils.has_arg(self.cell.call, 'training'):
727      kwargs['training'] = training
728
729    # TF RNN cells expect a single tensor as state instead of a list-wrapped tensor.
730    is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
731    if constants:
732      if not generic_utils.has_arg(self.cell.call, 'constants'):
733        raise ValueError('RNN cell does not support constants')
734
735      def step(inputs, states):
736        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
737        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
738
739        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
740        output, new_states = self.cell.call(
741            inputs, states, constants=constants, **kwargs)
742        if not nest.is_sequence(new_states):
743          new_states = [new_states]
744        return output, new_states
745    else:
746
747      def step(inputs, states):
748        states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
749        output, new_states = self.cell.call(inputs, states, **kwargs)
750        if not nest.is_sequence(new_states):
751          new_states = [new_states]
752        return output, new_states
753
754    last_output, outputs, states = K.rnn(
755        step,
756        inputs,
757        initial_state,
758        constants=constants,
759        go_backwards=self.go_backwards,
760        mask=mask,
761        unroll=self.unroll,
762        input_length=timesteps,
763        time_major=self.time_major,
764        zero_output_for_mask=self.zero_output_for_mask)
765    if self.stateful:
766      updates = []
767      for i in range(len(states)):
768        updates.append(state_ops.assign(self.states[i], states[i]))
769      self.add_update(updates, inputs)
770
771    if self.return_sequences:
772      output = outputs
773    else:
774      output = last_output
775
776    if self.return_state:
777      if not isinstance(states, (list, tuple)):
778        states = [states]
779      else:
780        states = list(states)
781      return generic_utils.to_list(output) + states
782    else:
783      return output
784
785  def _process_inputs(self, inputs, initial_state, constants):
786    # input shape: `(samples, time (padded with zeros), input_dim)`
787    # note that the .build() method of subclasses MUST define
788    # self.input_spec and self.state_spec with complete input shapes.
789    if (isinstance(inputs, collections.Sequence)
790        and not isinstance(inputs, tuple)):
791      # Get initial_state from the full input spec,
792      # as they could be copied to multiple GPUs.
793      if self._num_constants is None:
794        initial_state = inputs[1:]
795      else:
796        initial_state = inputs[1:-self._num_constants]
797        constants = inputs[-self._num_constants:]
798      if len(initial_state) == 0:
799        initial_state = None
800      inputs = inputs[0]
801    if initial_state is not None:
802      pass
803    elif self.stateful:
804      initial_state = self.states
805    else:
806      initial_state = self.get_initial_state(inputs)
807
808    if len(initial_state) != len(self.states):
809      raise ValueError('Layer has ' + str(len(self.states)) +
810                       ' states but was passed ' + str(len(initial_state)) +
811                       ' initial states.')
812    return inputs, initial_state, constants
813
814  def reset_states(self, states=None):
815    if not self.stateful:
816      raise AttributeError('Layer must be stateful.')
817    spec_shape = None if self.input_spec is None else self.input_spec[0].shape
818    if spec_shape is None:
819      # It is possible for the spec shape to be None, e.g. when constructing an
820      # RNN with a custom cell, or with standard RNN layers (LSTM/GRU), where we
821      # only know the input has 3 dims but not the full shape spec before build().
822      batch_size = None
823    else:
824      batch_size = spec_shape[1] if self.time_major else spec_shape[0]
825    if not batch_size:
826      raise ValueError('If a RNN is stateful, it needs to know '
827                       'its batch size. Specify the batch size '
828                       'of your input tensors: \n'
829                       '- If using a Sequential model, '
830                       'specify the batch size by passing '
831                       'a `batch_input_shape` '
832                       'argument to your first layer.\n'
833                       '- If using the functional API, specify '
834                       'the batch size by passing a '
835                       '`batch_shape` argument to your Input layer.')
836    # initialize state if None
837    if self.states[0] is None:
838      if _is_multiple_state(self.cell.state_size):
839        self.states = [
840            K.zeros([batch_size] + tensor_shape.as_shape(dim).as_list())
841            for dim in self.cell.state_size
842        ]
843      else:
844        self.states = [
845            K.zeros([batch_size] +
846                    tensor_shape.as_shape(self.cell.state_size).as_list())
847        ]
848    elif states is None:
849      if _is_multiple_state(self.cell.state_size):
850        for state, dim in zip(self.states, self.cell.state_size):
851          K.set_value(state,
852                      np.zeros([batch_size] +
853                               tensor_shape.as_shape(dim).as_list()))
854      else:
855        K.set_value(self.states[0], np.zeros(
856            [batch_size] +
857            tensor_shape.as_shape(self.cell.state_size).as_list()))
858    else:
859      if not isinstance(states, (list, tuple)):
860        states = [states]
861      if len(states) != len(self.states):
862        raise ValueError('Layer ' + self.name + ' expects ' +
863                         str(len(self.states)) + ' states, '
864                         'but it received ' + str(len(states)) +
865                         ' state values. Input received: ' + str(states))
866      for index, (value, state) in enumerate(zip(states, self.states)):
867        if _is_multiple_state(self.cell.state_size):
868          dim = self.cell.state_size[index]
869        else:
870          dim = self.cell.state_size
871        if value.shape != tuple([batch_size] +
872                                tensor_shape.as_shape(dim).as_list()):
873          raise ValueError(
874              'State ' + str(index) + ' is incompatible with layer ' +
875              self.name + ': expected shape=' + str(
876                  (batch_size, dim)) + ', found shape=' + str(value.shape))
877        # TODO(fchollet): consider batch calls to `set_value`.
878        K.set_value(state, value)
879
880  def get_config(self):
881    config = {
882        'return_sequences': self.return_sequences,
883        'return_state': self.return_state,
884        'go_backwards': self.go_backwards,
885        'stateful': self.stateful,
886        'unroll': self.unroll,
887        'time_major': self.time_major
888    }
889    if self._num_constants is not None:
890      config['num_constants'] = self._num_constants
891    if self.zero_output_for_mask:
892      config['zero_output_for_mask'] = self.zero_output_for_mask
893
894    cell_config = self.cell.get_config()
895    config['cell'] = {
896        'class_name': self.cell.__class__.__name__,
897        'config': cell_config
898    }
899    base_config = super(RNN, self).get_config()
900    return dict(list(base_config.items()) + list(config.items()))
901
902  @classmethod
903  def from_config(cls, config, custom_objects=None):
904    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
905    cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
906    num_constants = config.pop('num_constants', None)
907    layer = cls(cell, **config)
908    layer._num_constants = num_constants
909    return layer
910
911
912@keras_export('keras.layers.AbstractRNNCell')
913class AbstractRNNCell(Layer):
914  """Abstract object representing an RNN cell.
915
916  This is the base class for implementing RNN cells with custom behavior.
917
918  Every `RNNCell` must have the properties below and implement `call` with
919  the signature `(output, next_state) = call(input, state)`.
920
921  Examples:
922
923  ```python
924    class MinimalRNNCell(AbstractRNNCell):
925
926      def __init__(self, units, **kwargs):
927        self.units = units
928        super(MinimalRNNCell, self).__init__(**kwargs)
929
930      @property
931      def state_size(self):
932        return self.units
933
934      def build(self, input_shape):
935        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
936                                      initializer='uniform',
937                                      name='kernel')
938        self.recurrent_kernel = self.add_weight(
939            shape=(self.units, self.units),
940            initializer='uniform',
941            name='recurrent_kernel')
942        self.built = True
943
944      def call(self, inputs, states):
945        prev_output = states[0]
946        h = K.dot(inputs, self.kernel)
947        output = h + K.dot(prev_output, self.recurrent_kernel)
948        return output, output
949  ```
950
951  This definition of cell differs from the definition used in the literature.
952  In the literature, 'cell' refers to an object with a single scalar output.
953  This definition refers to a horizontal array of such units.
954
955  An RNN cell, in the most abstract setting, is anything that has
956  a state and performs some operation that takes a matrix of inputs.
957  This operation results in an output matrix with `self.output_size` columns.
958  If `self.state_size` is an integer, this operation also results in a new
959  state matrix with `self.state_size` columns.  If `self.state_size` is a
960  (possibly nested tuple of) TensorShape object(s), then it should return a
961  matching structure of Tensors having shape `[batch_size].concatenate(s)`
962  for each `s` in `self.state_size`.
963  """
964
965  def call(self, inputs, states):
966    """The function that contains the logic for one RNN step calculation.
967
968    Args:
969      inputs: the input tensor, which is a slice from the overall RNN input by
970        the time dimension (usually the second dimension).
971      states: the state tensor from the previous step, which has the shape
972        `(batch, state_size)`. In the case of timestep 0, it will be the
973        initial state specified by the user, or a zero-filled tensor otherwise.
974
975    Returns:
976      A tuple of two tensors:
977        1. output tensor for the current timestep, with size `output_size`.
978        2. state tensor for next step, which has the shape of `state_size`.
979    """
980    raise NotImplementedError('Abstract method')
981
982  @property
983  def state_size(self):
984    """size(s) of state(s) used by this cell.
985
986    It can be represented by an Integer, a TensorShape or a tuple of Integers
987    or TensorShapes.
988    """
989    raise NotImplementedError('Abstract method')
990
991  @property
992  def output_size(self):
993    """Integer or TensorShape: size of outputs produced by this cell."""
994    raise NotImplementedError('Abstract method')
995
996  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
997    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
998
999
1000class DropoutRNNCellMixin(object):
1001  """Object that holds dropout-related fields for an RNN cell.
1002
1003  This class is not a standalone RNN cell. It is supposed to be used with an RNN
1004  cell by multiple inheritance. Any cell that mixes in this class should have
1005  the following fields:
1006    dropout: a float number within the range [0, 1). The fraction of the input
1007      tensor to drop out.
1008    recurrent_dropout: a float number within the range [0, 1). The fraction of
1009      the recurrent state weights to drop out.
1010  This object will create and cache dropout masks, and reuse them for the
1011  incoming data, so that the same mask is used for every timestep in a batch.
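
  A sketch of the expected usage pattern from within a cell's `call()` method
  (it mirrors `SimpleRNNCell.call` below):

  ```python
  def call(self, inputs, states, training=None):
    # `self` mixes in DropoutRNNCellMixin and defines `self.dropout` and
    # `self.recurrent_dropout`.
    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(states[0], training)
    prev_output = states[0]
    if dp_mask is not None:
      inputs = inputs * dp_mask
    if rec_dp_mask is not None:
      prev_output = prev_output * rec_dp_mask
    # ... the rest of the step computation uses `inputs` and `prev_output`.
  ```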
1012  """
1013
1014  def __init__(self, *args, **kwargs):
1015    # Note that the following two masks will be used in "graph function" mode,
1016    # e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
1017    # tensors will be generated differently than in the "graph function" case,
1018    # and they will be cached.
1019    # Also note that in graph mode, we still cache those masks only because the
1020    # RNN could be created with `unroll=True`. In that case, the `cell.call()`
1021    # function will be invoked multiple times, and we want to ensure same mask
1022    # is used every time.
1023    self._dropout_mask = None
1024    self._recurrent_dropout_mask = None
1025    self._eager_dropout_mask = None
1026    self._eager_recurrent_dropout_mask = None
1027    super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
1028
1029  def reset_dropout_mask(self):
1030    """Reset the cached dropout masks if any.
1031
1032    It is important for the RNN layer to invoke this in its call() method so
1033    that the cached mask is cleared before calling cell.call(). The mask
1034    should be cached across all timesteps within the same batch, but shouldn't
1035    be cached between batches. Otherwise it will introduce an unreasonable bias
1036    against certain indices of data within the batch.
1037    """
1038    self._dropout_mask = None
1039    self._eager_dropout_mask = None
1040
1041  def reset_recurrent_dropout_mask(self):
1042    """Reset the cached recurrent dropout masks if any.
1043
1044    It is important for the RNN layer to invoke this in its call() method so
1045    that the cached mask is cleared before calling cell.call(). The mask
1046    should be cached across all timesteps within the same batch, but shouldn't
1047    be cached between batches. Otherwise it will introduce an unreasonable bias
1048    against certain indices of data within the batch.
1049    """
1050    self._recurrent_dropout_mask = None
1051    self._eager_recurrent_dropout_mask = None
1052
1053  def get_dropout_mask_for_cell(self, inputs, training, count=1):
1054    """Get the dropout mask for RNN cell's input.
1055
1056    It will create a mask based on context if there isn't an existing cached
1057    mask. If a new mask is generated, it will update the cache in the cell.
1058
1059    Args:
1060      inputs: the input tensor whose shape will be used to generate the dropout
1061        mask.
1062      training: boolean tensor, whether it is in training mode; dropout will be
1063        ignored in non-training mode.
1064      count: int, how many dropout masks will be generated. This is useful for
1065        cells that have internal weights fused together.
1066    Returns:
1067      List of mask tensors; generated or cached masks based on context.
1068    """
1069    if self.dropout == 0:
1070      return None
1071    if (not context.executing_eagerly() and self._dropout_mask is None
1072        or context.executing_eagerly() and self._eager_dropout_mask is None):
1073      # Generate new mask and cache it based on context.
1074      dp_mask = _generate_dropout_mask(
1075          array_ops.ones_like(inputs),
1076          self.dropout,
1077          training=training,
1078          count=count)
1079      if context.executing_eagerly():
1080        self._eager_dropout_mask = dp_mask
1081      else:
1082        self._dropout_mask = dp_mask
1083    else:
1084      # Reuse the existing mask.
1085      dp_mask = (self._eager_dropout_mask
1086                 if context.executing_eagerly() else self._dropout_mask)
1087    return dp_mask
1088
1089  def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
1090    """Get the recurrent dropout mask for RNN cell.
1091
1092    It will create a mask based on context if there isn't an existing cached
1093    mask. If a new mask is generated, it will update the cache in the cell.
1094
1095    Args:
1096      inputs: the input tensor whose shape will be used to generate the dropout
1097        mask.
1098      training: boolean tensor, whether it is in training mode; dropout will be
1099        ignored in non-training mode.
1100      count: int, how many dropout masks will be generated. This is useful for
1101        cells that have internal weights fused together.
1102    Returns:
1103      List of mask tensors; generated or cached masks based on context.
1104    """
1105    if self.recurrent_dropout == 0:
1106      return None
1107    if (not context.executing_eagerly() and self._recurrent_dropout_mask is None
1108        or context.executing_eagerly()
1109        and self._eager_recurrent_dropout_mask is None):
1110      # Generate new mask and cache it based on context.
1111      rec_dp_mask = _generate_dropout_mask(
1112          array_ops.ones_like(inputs),
1113          self.recurrent_dropout,
1114          training=training,
1115          count=count)
1116      if context.executing_eagerly():
1117        self._eager_recurrent_dropout_mask = rec_dp_mask
1118      else:
1119        self._recurrent_dropout_mask = rec_dp_mask
1120    else:
1121      # Reuse the existing mask.
1122      rec_dp_mask = (self._eager_recurrent_dropout_mask
1123                     if context.executing_eagerly()
1124                     else self._recurrent_dropout_mask)
1125    return rec_dp_mask
1126
1127
1128@keras_export('keras.layers.SimpleRNNCell')
1129class SimpleRNNCell(DropoutRNNCellMixin, Layer):
1130  """Cell class for SimpleRNN.
1131
1132  Arguments:
1133    units: Positive integer, dimensionality of the output space.
1134    activation: Activation function to use.
1135      Default: hyperbolic tangent (`tanh`).
1136      If you pass `None`, no activation is applied
1137      (i.e. "linear" activation: `a(x) = x`).
1138    use_bias: Boolean, whether the layer uses a bias vector.
1139    kernel_initializer: Initializer for the `kernel` weights matrix,
1140      used for the linear transformation of the inputs.
1141    recurrent_initializer: Initializer for the `recurrent_kernel`
1142      weights matrix, used for the linear transformation of the recurrent state.
1143    bias_initializer: Initializer for the bias vector.
1144    kernel_regularizer: Regularizer function applied to
1145      the `kernel` weights matrix.
1146    recurrent_regularizer: Regularizer function applied to
1147      the `recurrent_kernel` weights matrix.
1148    bias_regularizer: Regularizer function applied to the bias vector.
1149    kernel_constraint: Constraint function applied to
1150      the `kernel` weights matrix.
1151    recurrent_constraint: Constraint function applied to
1152      the `recurrent_kernel` weights matrix.
1153    bias_constraint: Constraint function applied to the bias vector.
1154    dropout: Float between 0 and 1.
1155      Fraction of the units to drop for
1156      the linear transformation of the inputs.
1157    recurrent_dropout: Float between 0 and 1.
1158      Fraction of the units to drop for
1159      the linear transformation of the recurrent state.
1160
1161  Call arguments:
1162    inputs: A 2D tensor.
1163    states: List of state tensors corresponding to the previous timestep.
1164    training: Python boolean indicating whether the layer should behave in
1165      training mode or in inference mode. Only relevant when `dropout` or
1166      `recurrent_dropout` is used.
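
  For example (a sketch; the dimensions are illustrative):

  ```python
  cell = keras.layers.SimpleRNNCell(32)
  x = keras.Input((None, 5))
  layer = keras.layers.RNN(cell)
  y = layer(x)
  ```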
1167  """
1168
1169  def __init__(self,
1170               units,
1171               activation='tanh',
1172               use_bias=True,
1173               kernel_initializer='glorot_uniform',
1174               recurrent_initializer='orthogonal',
1175               bias_initializer='zeros',
1176               kernel_regularizer=None,
1177               recurrent_regularizer=None,
1178               bias_regularizer=None,
1179               kernel_constraint=None,
1180               recurrent_constraint=None,
1181               bias_constraint=None,
1182               dropout=0.,
1183               recurrent_dropout=0.,
1184               **kwargs):
1185    super(SimpleRNNCell, self).__init__(**kwargs)
1186    self.units = units
1187    self.activation = activations.get(activation)
1188    self.use_bias = use_bias
1189
1190    self.kernel_initializer = initializers.get(kernel_initializer)
1191    self.recurrent_initializer = initializers.get(recurrent_initializer)
1192    self.bias_initializer = initializers.get(bias_initializer)
1193
1194    self.kernel_regularizer = regularizers.get(kernel_regularizer)
1195    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1196    self.bias_regularizer = regularizers.get(bias_regularizer)
1197
1198    self.kernel_constraint = constraints.get(kernel_constraint)
1199    self.recurrent_constraint = constraints.get(recurrent_constraint)
1200    self.bias_constraint = constraints.get(bias_constraint)
1201
1202    self.dropout = min(1., max(0., dropout))
1203    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1204    self.state_size = self.units
1205    self.output_size = self.units
1206
1207  @tf_utils.shape_type_conversion
1208  def build(self, input_shape):
1209    self.kernel = self.add_weight(
1210        shape=(input_shape[-1], self.units),
1211        name='kernel',
1212        initializer=self.kernel_initializer,
1213        regularizer=self.kernel_regularizer,
1214        constraint=self.kernel_constraint)
1215    self.recurrent_kernel = self.add_weight(
1216        shape=(self.units, self.units),
1217        name='recurrent_kernel',
1218        initializer=self.recurrent_initializer,
1219        regularizer=self.recurrent_regularizer,
1220        constraint=self.recurrent_constraint)
1221    if self.use_bias:
1222      self.bias = self.add_weight(
1223          shape=(self.units,),
1224          name='bias',
1225          initializer=self.bias_initializer,
1226          regularizer=self.bias_regularizer,
1227          constraint=self.bias_constraint)
1228    else:
1229      self.bias = None
1230    self.built = True
1231
1232  def call(self, inputs, states, training=None):
1233    prev_output = states[0]
1234    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
1235    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1236        prev_output, training)
1237
1238    if dp_mask is not None:
1239      h = K.dot(inputs * dp_mask, self.kernel)
1240    else:
1241      h = K.dot(inputs, self.kernel)
1242    if self.bias is not None:
1243      h = K.bias_add(h, self.bias)
1244
1245    if rec_dp_mask is not None:
1246      prev_output *= rec_dp_mask
1247    output = h + K.dot(prev_output, self.recurrent_kernel)
1248    if self.activation is not None:
1249      output = self.activation(output)
1250
1251    return output, [output]
1252
1253  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1254    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1255
1256  def get_config(self):
1257    config = {
1258        'units':
1259            self.units,
1260        'activation':
1261            activations.serialize(self.activation),
1262        'use_bias':
1263            self.use_bias,
1264        'kernel_initializer':
1265            initializers.serialize(self.kernel_initializer),
1266        'recurrent_initializer':
1267            initializers.serialize(self.recurrent_initializer),
1268        'bias_initializer':
1269            initializers.serialize(self.bias_initializer),
1270        'kernel_regularizer':
1271            regularizers.serialize(self.kernel_regularizer),
1272        'recurrent_regularizer':
1273            regularizers.serialize(self.recurrent_regularizer),
1274        'bias_regularizer':
1275            regularizers.serialize(self.bias_regularizer),
1276        'kernel_constraint':
1277            constraints.serialize(self.kernel_constraint),
1278        'recurrent_constraint':
1279            constraints.serialize(self.recurrent_constraint),
1280        'bias_constraint':
1281            constraints.serialize(self.bias_constraint),
1282        'dropout':
1283            self.dropout,
1284        'recurrent_dropout':
1285            self.recurrent_dropout
1286    }
1287    base_config = super(SimpleRNNCell, self).get_config()
1288    return dict(list(base_config.items()) + list(config.items()))
1289
1290
1291@keras_export('keras.layers.SimpleRNN')
1292class SimpleRNN(RNN):
1293  """Fully-connected RNN where the output is to be fed back to input.
1294
1295  Arguments:
1296    units: Positive integer, dimensionality of the output space.
1297    activation: Activation function to use.
1298      Default: hyperbolic tangent (`tanh`).
1299      If you pass None, no activation is applied
1300      (i.e. "linear" activation: `a(x) = x`).
1301    use_bias: Boolean, whether the layer uses a bias vector.
1302    kernel_initializer: Initializer for the `kernel` weights matrix,
1303      used for the linear transformation of the inputs.
1304    recurrent_initializer: Initializer for the `recurrent_kernel`
1305      weights matrix,
1306      used for the linear transformation of the recurrent state.
1307    bias_initializer: Initializer for the bias vector.
1308    kernel_regularizer: Regularizer function applied to
1309      the `kernel` weights matrix.
1310    recurrent_regularizer: Regularizer function applied to
1311      the `recurrent_kernel` weights matrix.
1312    bias_regularizer: Regularizer function applied to the bias vector.
1313    activity_regularizer: Regularizer function applied to
1314      the output of the layer (its "activation").
1315    kernel_constraint: Constraint function applied to
1316      the `kernel` weights matrix.
1317    recurrent_constraint: Constraint function applied to
1318      the `recurrent_kernel` weights matrix.
1319    bias_constraint: Constraint function applied to the bias vector.
1320    dropout: Float between 0 and 1.
1321      Fraction of the units to drop for
1322      the linear transformation of the inputs.
1323    recurrent_dropout: Float between 0 and 1.
1324      Fraction of the units to drop for
1325      the linear transformation of the recurrent state.
1326    return_sequences: Boolean. Whether to return the last output
1327      in the output sequence, or the full sequence.
1328    return_state: Boolean. Whether to return the last state
1329      in addition to the output.
1330    go_backwards: Boolean (default False).
1331      If True, process the input sequence backwards and return the
1332      reversed sequence.
1333    stateful: Boolean (default False). If True, the last state
1334      for each sample at index i in a batch will be used as initial
1335      state for the sample of index i in the following batch.
1336    unroll: Boolean (default False).
1337      If True, the network will be unrolled,
1338      else a symbolic loop will be used.
1339      Unrolling can speed up an RNN,
1340      although it tends to be more memory-intensive.
1341      Unrolling is only suitable for short sequences.
1342
1343  Call arguments:
1344    inputs: A 3D tensor.
1345    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
1346      a given timestep should be masked.
1347    training: Python boolean indicating whether the layer should behave in
1348      training mode or in inference mode. This argument is passed to the cell
1349      when calling it. This is only relevant if `dropout` or
1350      `recurrent_dropout` is used.
1351    initial_state: List of initial state tensors to be passed to the first
1352      call of the cell.
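
  Example (a minimal usage sketch; the shapes and values below are
  illustrative, not part of the layer's contract):

  ```python
  import numpy as np
  import tensorflow as tf

  inputs = np.random.random([32, 10, 8]).astype(np.float32)

  simple_rnn = tf.keras.layers.SimpleRNN(4)
  output = simple_rnn(inputs)  # output shape: (32, 4)

  simple_rnn = tf.keras.layers.SimpleRNN(
      4, return_sequences=True, return_state=True)
  # whole_sequence_output shape: (32, 10, 4); final_state shape: (32, 4)
  whole_sequence_output, final_state = simple_rnn(inputs)
  ```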
1353  """
1354
1355  def __init__(self,
1356               units,
1357               activation='tanh',
1358               use_bias=True,
1359               kernel_initializer='glorot_uniform',
1360               recurrent_initializer='orthogonal',
1361               bias_initializer='zeros',
1362               kernel_regularizer=None,
1363               recurrent_regularizer=None,
1364               bias_regularizer=None,
1365               activity_regularizer=None,
1366               kernel_constraint=None,
1367               recurrent_constraint=None,
1368               bias_constraint=None,
1369               dropout=0.,
1370               recurrent_dropout=0.,
1371               return_sequences=False,
1372               return_state=False,
1373               go_backwards=False,
1374               stateful=False,
1375               unroll=False,
1376               **kwargs):
1377    if 'implementation' in kwargs:
1378      kwargs.pop('implementation')
1379      logging.warning('The `implementation` argument '
1380                      'in `SimpleRNN` has been deprecated. '
1381                      'Please remove it from your layer call.')
1382    cell = SimpleRNNCell(
1383        units,
1384        activation=activation,
1385        use_bias=use_bias,
1386        kernel_initializer=kernel_initializer,
1387        recurrent_initializer=recurrent_initializer,
1388        bias_initializer=bias_initializer,
1389        kernel_regularizer=kernel_regularizer,
1390        recurrent_regularizer=recurrent_regularizer,
1391        bias_regularizer=bias_regularizer,
1392        kernel_constraint=kernel_constraint,
1393        recurrent_constraint=recurrent_constraint,
1394        bias_constraint=bias_constraint,
1395        dropout=dropout,
1396        recurrent_dropout=recurrent_dropout)
1397    super(SimpleRNN, self).__init__(
1398        cell,
1399        return_sequences=return_sequences,
1400        return_state=return_state,
1401        go_backwards=go_backwards,
1402        stateful=stateful,
1403        unroll=unroll,
1404        **kwargs)
1405    self.activity_regularizer = regularizers.get(activity_regularizer)
1406    self.input_spec = [InputSpec(ndim=3)]
1407
1408  def call(self, inputs, mask=None, training=None, initial_state=None):
1409    self.cell.reset_dropout_mask()
1410    self.cell.reset_recurrent_dropout_mask()
1411    return super(SimpleRNN, self).call(
1412        inputs, mask=mask, training=training, initial_state=initial_state)
1413
1414  @property
1415  def units(self):
1416    return self.cell.units
1417
1418  @property
1419  def activation(self):
1420    return self.cell.activation
1421
1422  @property
1423  def use_bias(self):
1424    return self.cell.use_bias
1425
1426  @property
1427  def kernel_initializer(self):
1428    return self.cell.kernel_initializer
1429
1430  @property
1431  def recurrent_initializer(self):
1432    return self.cell.recurrent_initializer
1433
1434  @property
1435  def bias_initializer(self):
1436    return self.cell.bias_initializer
1437
1438  @property
1439  def kernel_regularizer(self):
1440    return self.cell.kernel_regularizer
1441
1442  @property
1443  def recurrent_regularizer(self):
1444    return self.cell.recurrent_regularizer
1445
1446  @property
1447  def bias_regularizer(self):
1448    return self.cell.bias_regularizer
1449
1450  @property
1451  def kernel_constraint(self):
1452    return self.cell.kernel_constraint
1453
1454  @property
1455  def recurrent_constraint(self):
1456    return self.cell.recurrent_constraint
1457
1458  @property
1459  def bias_constraint(self):
1460    return self.cell.bias_constraint
1461
1462  @property
1463  def dropout(self):
1464    return self.cell.dropout
1465
1466  @property
1467  def recurrent_dropout(self):
1468    return self.cell.recurrent_dropout
1469
1470  def get_config(self):
1471    config = {
1472        'units':
1473            self.units,
1474        'activation':
1475            activations.serialize(self.activation),
1476        'use_bias':
1477            self.use_bias,
1478        'kernel_initializer':
1479            initializers.serialize(self.kernel_initializer),
1480        'recurrent_initializer':
1481            initializers.serialize(self.recurrent_initializer),
1482        'bias_initializer':
1483            initializers.serialize(self.bias_initializer),
1484        'kernel_regularizer':
1485            regularizers.serialize(self.kernel_regularizer),
1486        'recurrent_regularizer':
1487            regularizers.serialize(self.recurrent_regularizer),
1488        'bias_regularizer':
1489            regularizers.serialize(self.bias_regularizer),
1490        'activity_regularizer':
1491            regularizers.serialize(self.activity_regularizer),
1492        'kernel_constraint':
1493            constraints.serialize(self.kernel_constraint),
1494        'recurrent_constraint':
1495            constraints.serialize(self.recurrent_constraint),
1496        'bias_constraint':
1497            constraints.serialize(self.bias_constraint),
1498        'dropout':
1499            self.dropout,
1500        'recurrent_dropout':
1501            self.recurrent_dropout
1502    }
1503    base_config = super(SimpleRNN, self).get_config()
1504    del base_config['cell']
1505    return dict(list(base_config.items()) + list(config.items()))
1506
1507  @classmethod
1508  def from_config(cls, config):
1509    if 'implementation' in config:
1510      config.pop('implementation')
1511    return cls(**config)
1512
1513
1514@keras_export('keras.layers.GRUCell')
1515class GRUCell(DropoutRNNCellMixin, Layer):
1516  """Cell class for the GRU layer.
1517
1518  Arguments:
1519    units: Positive integer, dimensionality of the output space.
1520    activation: Activation function to use.
1521      Default: hyperbolic tangent (`tanh`).
1522      If you pass None, no activation is applied
1523      (i.e. "linear" activation: `a(x) = x`).
1524    recurrent_activation: Activation function to use
1525      for the recurrent step.
1526      Default: hard sigmoid (`hard_sigmoid`).
1527      If you pass `None`, no activation is applied
1528      (i.e. "linear" activation: `a(x) = x`).
1529    use_bias: Boolean, whether the layer uses a bias vector.
1530    kernel_initializer: Initializer for the `kernel` weights matrix,
1531      used for the linear transformation of the inputs.
1532    recurrent_initializer: Initializer for the `recurrent_kernel`
1533      weights matrix,
1534      used for the linear transformation of the recurrent state.
1535    bias_initializer: Initializer for the bias vector.
1536    kernel_regularizer: Regularizer function applied to
1537      the `kernel` weights matrix.
1538    recurrent_regularizer: Regularizer function applied to
1539      the `recurrent_kernel` weights matrix.
1540    bias_regularizer: Regularizer function applied to the bias vector.
1541    kernel_constraint: Constraint function applied to
1542      the `kernel` weights matrix.
1543    recurrent_constraint: Constraint function applied to
1544      the `recurrent_kernel` weights matrix.
1545    bias_constraint: Constraint function applied to the bias vector.
1546    dropout: Float between 0 and 1.
1547      Fraction of the units to drop for the linear transformation of the inputs.
1548    recurrent_dropout: Float between 0 and 1.
1549      Fraction of the units to drop for
1550      the linear transformation of the recurrent state.
1551    implementation: Implementation mode, either 1 or 2.
1552      Mode 1 will structure its operations as a larger number of
1553      smaller dot products and additions, whereas mode 2 will
1554      batch them into fewer, larger operations. These modes will
1555      have different performance profiles on different hardware and
1556      for different applications.
1557    reset_after: GRU convention (whether to apply reset gate after or
1558      before matrix multiplication). False = "before" (default),
1559      True = "after" (CuDNN compatible).
1560
1561  Call arguments:
1562    inputs: A 2D tensor.
1563    states: List of state tensors corresponding to the previous timestep.
1564    training: Python boolean indicating whether the layer should behave in
1565      training mode or in inference mode. Only relevant when `dropout` or
1566      `recurrent_dropout` is used.
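
  Example (a minimal usage sketch; shapes are illustrative). A cell processes
  a single timestep, so it is typically wrapped in a `keras.layers.RNN` layer
  to process a whole sequence:

  ```python
  import numpy as np
  import tensorflow as tf

  inputs = np.random.random([32, 10, 8]).astype(np.float32)

  rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4))
  output = rnn(inputs)  # output shape: (32, 4)

  rnn = tf.keras.layers.RNN(
      tf.keras.layers.GRUCell(4), return_sequences=True, return_state=True)
  # whole_sequence_output shape: (32, 10, 4); final_state shape: (32, 4)
  whole_sequence_output, final_state = rnn(inputs)
  ```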
1567  """
1568
1569  def __init__(self,
1570               units,
1571               activation='tanh',
1572               recurrent_activation='hard_sigmoid',
1573               use_bias=True,
1574               kernel_initializer='glorot_uniform',
1575               recurrent_initializer='orthogonal',
1576               bias_initializer='zeros',
1577               kernel_regularizer=None,
1578               recurrent_regularizer=None,
1579               bias_regularizer=None,
1580               kernel_constraint=None,
1581               recurrent_constraint=None,
1582               bias_constraint=None,
1583               dropout=0.,
1584               recurrent_dropout=0.,
1585               implementation=1,
1586               reset_after=False,
1587               **kwargs):
1588    super(GRUCell, self).__init__(**kwargs)
1589    self.units = units
1590    self.activation = activations.get(activation)
1591    self.recurrent_activation = activations.get(recurrent_activation)
1592    self.use_bias = use_bias
1593
1594    self.kernel_initializer = initializers.get(kernel_initializer)
1595    self.recurrent_initializer = initializers.get(recurrent_initializer)
1596    self.bias_initializer = initializers.get(bias_initializer)
1597
1598    self.kernel_regularizer = regularizers.get(kernel_regularizer)
1599    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
1600    self.bias_regularizer = regularizers.get(bias_regularizer)
1601
1602    self.kernel_constraint = constraints.get(kernel_constraint)
1603    self.recurrent_constraint = constraints.get(recurrent_constraint)
1604    self.bias_constraint = constraints.get(bias_constraint)
1605
1606    self.dropout = min(1., max(0., dropout))
1607    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
1608    self.implementation = implementation
1609    self.reset_after = reset_after
1610    self.state_size = self.units
1611    self.output_size = self.units
1612
1613  @tf_utils.shape_type_conversion
1614  def build(self, input_shape):
1615    input_dim = input_shape[-1]
1616    self.kernel = self.add_weight(
1617        shape=(input_dim, self.units * 3),
1618        name='kernel',
1619        initializer=self.kernel_initializer,
1620        regularizer=self.kernel_regularizer,
1621        constraint=self.kernel_constraint)
1622    self.recurrent_kernel = self.add_weight(
1623        shape=(self.units, self.units * 3),
1624        name='recurrent_kernel',
1625        initializer=self.recurrent_initializer,
1626        regularizer=self.recurrent_regularizer,
1627        constraint=self.recurrent_constraint)
1628
1629    if self.use_bias:
1630      if not self.reset_after:
1631        bias_shape = (3 * self.units,)
1632      else:
1633        # separate biases for input and recurrent kernels
1634        # Note: the shape is intentionally different from CuDNNGRU biases
1635        # `(2 * 3 * self.units,)`, so that we can distinguish the classes
1636        # when loading and converting saved weights.
1637        bias_shape = (2, 3 * self.units)
1638      self.bias = self.add_weight(shape=bias_shape,
1639                                  name='bias',
1640                                  initializer=self.bias_initializer,
1641                                  regularizer=self.bias_regularizer,
1642                                  constraint=self.bias_constraint)
1643    else:
1644      self.bias = None
1645    self.built = True
1646
1647  def call(self, inputs, states, training=None):
1648    h_tm1 = states[0]  # previous memory
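
    # GRU step, in the default ("reset before") convention:
    #   z = recurrent_activation(x_z + U_z . h_tm1)      (update gate)
    #   r = recurrent_activation(x_r + U_r . h_tm1)      (reset gate)
    #   hh = activation(x_h + U_h . (r * h_tm1))         (candidate state)
    #   h = z * h_tm1 + (1 - z) * hh
    # With `reset_after=True`, the reset gate is instead applied after the
    # recurrent matrix multiplication (the CuDNN-compatible convention).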
1649
1650    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
1651    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
1652        h_tm1, training, count=3)
1653
1654    if self.use_bias:
1655      if not self.reset_after:
1656        input_bias, recurrent_bias = self.bias, None
1657      else:
1658        input_bias, recurrent_bias = array_ops.unstack(self.bias)
1659
1660    if self.implementation == 1:
1661      if 0. < self.dropout < 1.:
1662        inputs_z = inputs * dp_mask[0]
1663        inputs_r = inputs * dp_mask[1]
1664        inputs_h = inputs * dp_mask[2]
1665      else:
1666        inputs_z = inputs
1667        inputs_r = inputs
1668        inputs_h = inputs
1669
1670      x_z = K.dot(inputs_z, self.kernel[:, :self.units])
1671      x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
1672      x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])
1673
1674      if self.use_bias:
1675        x_z = K.bias_add(x_z, input_bias[:self.units])
1676        x_r = K.bias_add(x_r, input_bias[self.units: self.units * 2])
1677        x_h = K.bias_add(x_h, input_bias[self.units * 2:])
1678
1679      if 0. < self.recurrent_dropout < 1.:
1680        h_tm1_z = h_tm1 * rec_dp_mask[0]
1681        h_tm1_r = h_tm1 * rec_dp_mask[1]
1682        h_tm1_h = h_tm1 * rec_dp_mask[2]
1683      else:
1684        h_tm1_z = h_tm1
1685        h_tm1_r = h_tm1
1686        h_tm1_h = h_tm1
1687
1688      recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
1689      recurrent_r = K.dot(h_tm1_r,
1690                          self.recurrent_kernel[:, self.units:self.units * 2])
1691      if self.reset_after and self.use_bias:
1692        recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units])
1693        recurrent_r = K.bias_add(recurrent_r,
1694                                 recurrent_bias[self.units:self.units * 2])
1695
1696      z = self.recurrent_activation(x_z + recurrent_z)
1697      r = self.recurrent_activation(x_r + recurrent_r)
1698
1699      # reset gate applied after/before matrix multiplication
1700      if self.reset_after:
1701        recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
1702        if self.use_bias:
1703          recurrent_h = K.bias_add(recurrent_h, recurrent_bias[self.units * 2:])
1704        recurrent_h = r * recurrent_h
1705      else:
1706        recurrent_h = K.dot(r * h_tm1_h,
1707                            self.recurrent_kernel[:, self.units * 2:])
1708
1709      hh = self.activation(x_h + recurrent_h)
1710    else:
1711      if 0. < self.dropout < 1.:
1712        inputs *= dp_mask[0]
1713
1714      # inputs projected by all gate matrices at once
1715      matrix_x = K.dot(inputs, self.kernel)
1716      if self.use_bias:
1717        # biases: bias_z_i, bias_r_i, bias_h_i
1718        matrix_x = K.bias_add(matrix_x, input_bias)
1719
1720      x_z = matrix_x[:, :self.units]
1721      x_r = matrix_x[:, self.units: 2 * self.units]
1722      x_h = matrix_x[:, 2 * self.units:]
1723
1724      if 0. < self.recurrent_dropout < 1.:
1725        h_tm1 *= rec_dp_mask[0]
1726
1727      if self.reset_after:
1728        # hidden state projected by all gate matrices at once
1729        matrix_inner = K.dot(h_tm1, self.recurrent_kernel)
1730        if self.use_bias:
1731          matrix_inner = K.bias_add(matrix_inner, recurrent_bias)
1732      else:
1733        # hidden state projected here only for the update/reset gates
1734        matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
1735
1736      recurrent_z = matrix_inner[:, :self.units]
1737      recurrent_r = matrix_inner[:, self.units:2 * self.units]
1738
1739      z = self.recurrent_activation(x_z + recurrent_z)
1740      r = self.recurrent_activation(x_r + recurrent_r)
1741
1742      if self.reset_after:
1743        recurrent_h = r * matrix_inner[:, 2 * self.units:]
1744      else:
1745        recurrent_h = K.dot(r * h_tm1,
1746                            self.recurrent_kernel[:, 2 * self.units:])
1747
1748      hh = self.activation(x_h + recurrent_h)
1749    # previous and candidate state mixed by update gate
1750    h = z * h_tm1 + (1 - z) * hh
1751    return h, [h]
1752
1753  def get_config(self):
1754    config = {
1755        'units': self.units,
1756        'activation': activations.serialize(self.activation),
1757        'recurrent_activation':
1758            activations.serialize(self.recurrent_activation),
1759        'use_bias': self.use_bias,
1760        'kernel_initializer': initializers.serialize(self.kernel_initializer),
1761        'recurrent_initializer':
1762            initializers.serialize(self.recurrent_initializer),
1763        'bias_initializer': initializers.serialize(self.bias_initializer),
1764        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
1765        'recurrent_regularizer':
1766            regularizers.serialize(self.recurrent_regularizer),
1767        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
1768        'kernel_constraint': constraints.serialize(self.kernel_constraint),
1769        'recurrent_constraint':
1770            constraints.serialize(self.recurrent_constraint),
1771        'bias_constraint': constraints.serialize(self.bias_constraint),
1772        'dropout': self.dropout,
1773        'recurrent_dropout': self.recurrent_dropout,
1774        'implementation': self.implementation,
1775        'reset_after': self.reset_after
1776    }
1777    base_config = super(GRUCell, self).get_config()
1778    return dict(list(base_config.items()) + list(config.items()))
1779
1780  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
1781    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
1782
1783
1784@keras_export(v1=['keras.layers.GRU'])
1785class GRU(RNN):
1786  """Gated Recurrent Unit - Cho et al. 2014.
1787
1788  There are two variants. The default one is based on arXiv:1406.1078v3 and
1789  applies the reset gate to the hidden state before the matrix multiplication.
1790  The other is based on the original arXiv:1406.1078v1 and reverses that order.
1791
1792  The second variant is compatible with CuDNNGRU (GPU-only) and allows
1793  inference on CPU. It therefore has separate biases for `kernel` and
1794  `recurrent_kernel`. To use this variant, set `reset_after=True` and
1795  `recurrent_activation='sigmoid'`.
1796
1797  Arguments:
1798    units: Positive integer, dimensionality of the output space.
1799    activation: Activation function to use.
1800      Default: hyperbolic tangent (`tanh`).
1801      If you pass `None`, no activation is applied
1802      (i.e. "linear" activation: `a(x) = x`).
1803    recurrent_activation: Activation function to use
1804      for the recurrent step.
1805      Default: hard sigmoid (`hard_sigmoid`).
1806      If you pass `None`, no activation is applied
1807      (i.e. "linear" activation: `a(x) = x`).
1808    use_bias: Boolean, whether the layer uses a bias vector.
1809    kernel_initializer: Initializer for the `kernel` weights matrix,
1810      used for the linear transformation of the inputs.
1811    recurrent_initializer: Initializer for the `recurrent_kernel`
1812      weights matrix, used for the linear transformation of the recurrent state.
1813    bias_initializer: Initializer for the bias vector.
1814    kernel_regularizer: Regularizer function applied to
1815      the `kernel` weights matrix.
1816    recurrent_regularizer: Regularizer function applied to
1817      the `recurrent_kernel` weights matrix.
1818    bias_regularizer: Regularizer function applied to the bias vector.
1819    activity_regularizer: Regularizer function applied to
1820      the output of the layer (its "activation").
1821    kernel_constraint: Constraint function applied to
1822      the `kernel` weights matrix.
1823    recurrent_constraint: Constraint function applied to
1824      the `recurrent_kernel` weights matrix.
1825    bias_constraint: Constraint function applied to the bias vector.
1826    dropout: Float between 0 and 1.
1827      Fraction of the units to drop for
1828      the linear transformation of the inputs.
1829    recurrent_dropout: Float between 0 and 1.
1830      Fraction of the units to drop for
1831      the linear transformation of the recurrent state.
1832    implementation: Implementation mode, either 1 or 2.
1833      Mode 1 will structure its operations as a larger number of
1834      smaller dot products and additions, whereas mode 2 will
1835      batch them into fewer, larger operations. These modes will
1836      have different performance profiles on different hardware and
1837      for different applications.
1838    return_sequences: Boolean. Whether to return the last output
1839      in the output sequence, or the full sequence.
1840    return_state: Boolean. Whether to return the last state
1841      in addition to the output.
1842    go_backwards: Boolean (default False).
1843      If True, process the input sequence backwards and return the
1844      reversed sequence.
1845    stateful: Boolean (default False). If True, the last state
1846      for each sample at index i in a batch will be used as initial
1847      state for the sample of index i in the following batch.
1848    unroll: Boolean (default False).
1849      If True, the network will be unrolled,
1850      else a symbolic loop will be used.
1851      Unrolling can speed up an RNN,
1852      although it tends to be more memory-intensive.
1853      Unrolling is only suitable for short sequences.
1854    reset_after: GRU convention (whether to apply reset gate after or
1855      before matrix multiplication). False = "before" (default),
1856      True = "after" (CuDNN compatible).
1857
1858  Call arguments:
1859    inputs: A 3D tensor.
1860    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
1861      a given timestep should be masked.
1862    training: Python boolean indicating whether the layer should behave in
1863      training mode or in inference mode. This argument is passed to the cell
1864      when calling it. This is only relevant if `dropout` or
1865      `recurrent_dropout` is used.
1866    initial_state: List of initial state tensors to be passed to the first
1867      call of the cell.
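
  Example (a minimal usage sketch; shapes are illustrative):

  ```python
  import numpy as np
  import tensorflow as tf

  inputs = np.random.random([32, 10, 8]).astype(np.float32)

  gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)
  # whole_sequence_output shape: (32, 10, 4); final_state shape: (32, 4)
  whole_sequence_output, final_state = gru(inputs)

  # For weights that can be swapped with `CuDNNGRU`, use the second variant:
  cudnn_compatible_gru = tf.keras.layers.GRU(
      4, reset_after=True, recurrent_activation='sigmoid')
  ```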
1868  """
1869
1870  def __init__(self,
1871               units,
1872               activation='tanh',
1873               recurrent_activation='hard_sigmoid',
1874               use_bias=True,
1875               kernel_initializer='glorot_uniform',
1876               recurrent_initializer='orthogonal',
1877               bias_initializer='zeros',
1878               kernel_regularizer=None,
1879               recurrent_regularizer=None,
1880               bias_regularizer=None,
1881               activity_regularizer=None,
1882               kernel_constraint=None,
1883               recurrent_constraint=None,
1884               bias_constraint=None,
1885               dropout=0.,
1886               recurrent_dropout=0.,
1887               implementation=1,
1888               return_sequences=False,
1889               return_state=False,
1890               go_backwards=False,
1891               stateful=False,
1892               unroll=False,
1893               reset_after=False,
1894               **kwargs):
1895    if implementation == 0:
1896      logging.warning('`implementation=0` has been deprecated, '
1897                      'and now defaults to `implementation=1`. '
1898                      'Please update your layer call.')
1899    cell = GRUCell(
1900        units,
1901        activation=activation,
1902        recurrent_activation=recurrent_activation,
1903        use_bias=use_bias,
1904        kernel_initializer=kernel_initializer,
1905        recurrent_initializer=recurrent_initializer,
1906        bias_initializer=bias_initializer,
1907        kernel_regularizer=kernel_regularizer,
1908        recurrent_regularizer=recurrent_regularizer,
1909        bias_regularizer=bias_regularizer,
1910        kernel_constraint=kernel_constraint,
1911        recurrent_constraint=recurrent_constraint,
1912        bias_constraint=bias_constraint,
1913        dropout=dropout,
1914        recurrent_dropout=recurrent_dropout,
1915        implementation=implementation,
1916        reset_after=reset_after)
1917    super(GRU, self).__init__(
1918        cell,
1919        return_sequences=return_sequences,
1920        return_state=return_state,
1921        go_backwards=go_backwards,
1922        stateful=stateful,
1923        unroll=unroll,
1924        **kwargs)
1925    self.activity_regularizer = regularizers.get(activity_regularizer)
1926    self.input_spec = [InputSpec(ndim=3)]
1927
1928  def call(self, inputs, mask=None, training=None, initial_state=None):
1929    self.cell.reset_dropout_mask()
1930    self.cell.reset_recurrent_dropout_mask()
1931    return super(GRU, self).call(
1932        inputs, mask=mask, training=training, initial_state=initial_state)
1933
1934  @property
1935  def units(self):
1936    return self.cell.units
1937
1938  @property
1939  def activation(self):
1940    return self.cell.activation
1941
1942  @property
1943  def recurrent_activation(self):
1944    return self.cell.recurrent_activation
1945
1946  @property
1947  def use_bias(self):
1948    return self.cell.use_bias
1949
1950  @property
1951  def kernel_initializer(self):
1952    return self.cell.kernel_initializer
1953
1954  @property
1955  def recurrent_initializer(self):
1956    return self.cell.recurrent_initializer
1957
1958  @property
1959  def bias_initializer(self):
1960    return self.cell.bias_initializer
1961
1962  @property
1963  def kernel_regularizer(self):
1964    return self.cell.kernel_regularizer
1965
1966  @property
1967  def recurrent_regularizer(self):
1968    return self.cell.recurrent_regularizer
1969
1970  @property
1971  def bias_regularizer(self):
1972    return self.cell.bias_regularizer
1973
1974  @property
1975  def kernel_constraint(self):
1976    return self.cell.kernel_constraint
1977
1978  @property
1979  def recurrent_constraint(self):
1980    return self.cell.recurrent_constraint
1981
1982  @property
1983  def bias_constraint(self):
1984    return self.cell.bias_constraint
1985
1986  @property
1987  def dropout(self):
1988    return self.cell.dropout
1989
1990  @property
1991  def recurrent_dropout(self):
1992    return self.cell.recurrent_dropout
1993
1994  @property
1995  def implementation(self):
1996    return self.cell.implementation
1997
1998  @property
1999  def reset_after(self):
2000    return self.cell.reset_after
2001
2002  def get_config(self):
2003    config = {
2004        'units':
2005            self.units,
2006        'activation':
2007            activations.serialize(self.activation),
2008        'recurrent_activation':
2009            activations.serialize(self.recurrent_activation),
2010        'use_bias':
2011            self.use_bias,
2012        'kernel_initializer':
2013            initializers.serialize(self.kernel_initializer),
2014        'recurrent_initializer':
2015            initializers.serialize(self.recurrent_initializer),
2016        'bias_initializer':
2017            initializers.serialize(self.bias_initializer),
2018        'kernel_regularizer':
2019            regularizers.serialize(self.kernel_regularizer),
2020        'recurrent_regularizer':
2021            regularizers.serialize(self.recurrent_regularizer),
2022        'bias_regularizer':
2023            regularizers.serialize(self.bias_regularizer),
2024        'activity_regularizer':
2025            regularizers.serialize(self.activity_regularizer),
2026        'kernel_constraint':
2027            constraints.serialize(self.kernel_constraint),
2028        'recurrent_constraint':
2029            constraints.serialize(self.recurrent_constraint),
2030        'bias_constraint':
2031            constraints.serialize(self.bias_constraint),
2032        'dropout':
2033            self.dropout,
2034        'recurrent_dropout':
2035            self.recurrent_dropout,
2036        'implementation':
2037            self.implementation,
2038        'reset_after':
2039            self.reset_after
2040    }
2041    base_config = super(GRU, self).get_config()
2042    del base_config['cell']
2043    return dict(list(base_config.items()) + list(config.items()))
2044
2045  @classmethod
2046  def from_config(cls, config):
2047    if 'implementation' in config and config['implementation'] == 0:
2048      config['implementation'] = 1
2049    return cls(**config)
2050
2051
2052@keras_export('keras.layers.LSTMCell')
2053class LSTMCell(DropoutRNNCellMixin, Layer):
2054  """Cell class for the LSTM layer.
2055
2056  Arguments:
2057    units: Positive integer, dimensionality of the output space.
2058    activation: Activation function to use.
2059      Default: hyperbolic tangent (`tanh`).
2060      If you pass `None`, no activation is applied
2061      (i.e. "linear" activation: `a(x) = x`).
2062    recurrent_activation: Activation function to use
2063      for the recurrent step.
2064      Default: hard sigmoid (`hard_sigmoid`).
2065      If you pass `None`, no activation is applied
2066      (i.e. "linear" activation: `a(x) = x`).
2067    use_bias: Boolean, whether the layer uses a bias vector.
2068    kernel_initializer: Initializer for the `kernel` weights matrix,
2069      used for the linear transformation of the inputs.
2070    recurrent_initializer: Initializer for the `recurrent_kernel`
2071      weights matrix,
2072      used for the linear transformation of the recurrent state.
2073    bias_initializer: Initializer for the bias vector.
2074    unit_forget_bias: Boolean.
2075      If True, add 1 to the bias of the forget gate at initialization.
2076      Setting it to true will also force `bias_initializer="zeros"`.
2077      This is recommended in [Jozefowicz et
2078        al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
2079    kernel_regularizer: Regularizer function applied to
2080      the `kernel` weights matrix.
2081    recurrent_regularizer: Regularizer function applied to
2082      the `recurrent_kernel` weights matrix.
2083    bias_regularizer: Regularizer function applied to the bias vector.
2084    kernel_constraint: Constraint function applied to
2085      the `kernel` weights matrix.
2086    recurrent_constraint: Constraint function applied to
2087      the `recurrent_kernel` weights matrix.
2088    bias_constraint: Constraint function applied to the bias vector.
2089    dropout: Float between 0 and 1.
2090      Fraction of the units to drop for
2091      the linear transformation of the inputs.
2092    recurrent_dropout: Float between 0 and 1.
2093      Fraction of the units to drop for
2094      the linear transformation of the recurrent state.
2095    implementation: Implementation mode, either 1 or 2.
2096      Mode 1 will structure its operations as a larger number of
2097      smaller dot products and additions, whereas mode 2 will
2098      batch them into fewer, larger operations. These modes will
2099      have different performance profiles on different hardware and
2100      for different applications.
2101
2102  Call arguments:
2103    inputs: A 2D tensor.
2104    states: List of state tensors corresponding to the previous timestep.
2105    training: Python boolean indicating whether the layer should behave in
2106      training mode or in inference mode. Only relevant when `dropout` or
2107      `recurrent_dropout` is used.
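
  Example (a minimal usage sketch; shapes are illustrative). As with other
  cells, a single `LSTMCell` handles one timestep and is usually wrapped in a
  `keras.layers.RNN` layer:

  ```python
  import numpy as np
  import tensorflow as tf

  inputs = np.random.random([32, 10, 8]).astype(np.float32)

  rnn = tf.keras.layers.RNN(
      tf.keras.layers.LSTMCell(4), return_sequences=True, return_state=True)
  # whole_sequence_output shape: (32, 10, 4); each final state shape: (32, 4)
  whole_sequence_output, final_memory_state, final_carry_state = rnn(inputs)
  ```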
2108  """
2109
2110  def __init__(self,
2111               units,
2112               activation='tanh',
2113               recurrent_activation='hard_sigmoid',
2114               use_bias=True,
2115               kernel_initializer='glorot_uniform',
2116               recurrent_initializer='orthogonal',
2117               bias_initializer='zeros',
2118               unit_forget_bias=True,
2119               kernel_regularizer=None,
2120               recurrent_regularizer=None,
2121               bias_regularizer=None,
2122               kernel_constraint=None,
2123               recurrent_constraint=None,
2124               bias_constraint=None,
2125               dropout=0.,
2126               recurrent_dropout=0.,
2127               implementation=1,
2128               **kwargs):
2129    super(LSTMCell, self).__init__(**kwargs)
2130    self.units = units
2131    self.activation = activations.get(activation)
2132    self.recurrent_activation = activations.get(recurrent_activation)
2133    self.use_bias = use_bias
2134
2135    self.kernel_initializer = initializers.get(kernel_initializer)
2136    self.recurrent_initializer = initializers.get(recurrent_initializer)
2137    self.bias_initializer = initializers.get(bias_initializer)
2138    self.unit_forget_bias = unit_forget_bias
2139
2140    self.kernel_regularizer = regularizers.get(kernel_regularizer)
2141    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
2142    self.bias_regularizer = regularizers.get(bias_regularizer)
2143
2144    self.kernel_constraint = constraints.get(kernel_constraint)
2145    self.recurrent_constraint = constraints.get(recurrent_constraint)
2146    self.bias_constraint = constraints.get(bias_constraint)
2147
2148    self.dropout = min(1., max(0., dropout))
2149    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
2150    self.implementation = implementation
2151    self.state_size = [self.units, self.units]
2152    self.output_size = self.units
2153
2154  @tf_utils.shape_type_conversion
2155  def build(self, input_shape):
2156    input_dim = input_shape[-1]
2157    self.kernel = self.add_weight(
2158        shape=(input_dim, self.units * 4),
2159        name='kernel',
2160        initializer=self.kernel_initializer,
2161        regularizer=self.kernel_regularizer,
2162        constraint=self.kernel_constraint)
2163    self.recurrent_kernel = self.add_weight(
2164        shape=(self.units, self.units * 4),
2165        name='recurrent_kernel',
2166        initializer=self.recurrent_initializer,
2167        regularizer=self.recurrent_regularizer,
2168        constraint=self.recurrent_constraint)
2169
2170    if self.use_bias:
2171      if self.unit_forget_bias:
2172
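        # Initialize the forget-gate slice of the bias to ones and the rest
        # with `bias_initializer` (gate order: i, f, c, o), so the cell starts
        # out remembering its carry state, as recommended by Jozefowicz et al.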
2173        def bias_initializer(_, *args, **kwargs):
2174          return K.concatenate([
2175              self.bias_initializer((self.units,), *args, **kwargs),
2176              initializers.Ones()((self.units,), *args, **kwargs),
2177              self.bias_initializer((self.units * 2,), *args, **kwargs),
2178          ])
2179      else:
2180        bias_initializer = self.bias_initializer
2181      self.bias = self.add_weight(
2182          shape=(self.units * 4,),
2183          name='bias',
2184          initializer=bias_initializer,
2185          regularizer=self.bias_regularizer,
2186          constraint=self.bias_constraint)
2187    else:
2188      self.bias = None
2189    self.built = True
2190
2191  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
2192    """Computes carry and output using split kernels."""
2193    x_i, x_f, x_c, x_o = x
2194    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
2195    i = self.recurrent_activation(
2196        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
2197    f = self.recurrent_activation(x_f + K.dot(
2198        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
2199    c = f * c_tm1 + i * self.activation(x_c + K.dot(
2200        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2201    o = self.recurrent_activation(
2202        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
2203    return c, o
2204
2205  def _compute_carry_and_output_fused(self, z, c_tm1):
2206    """Computes carry and output using fused kernels."""
2207    z0, z1, z2, z3 = z
2208    i = self.recurrent_activation(z0)
2209    f = self.recurrent_activation(z1)
2210    c = f * c_tm1 + i * self.activation(z2)
2211    o = self.recurrent_activation(z3)
2212    return c, o
2213
2214  def call(self, inputs, states, training=None):
2215    h_tm1 = states[0]  # previous memory state
2216    c_tm1 = states[1]  # previous carry state
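
    # LSTM step:
    #   i = recurrent_activation(x_i + U_i . h_tm1)            (input gate)
    #   f = recurrent_activation(x_f + U_f . h_tm1)            (forget gate)
    #   c = f * c_tm1 + i * activation(x_c + U_c . h_tm1)      (new carry)
    #   o = recurrent_activation(x_o + U_o . h_tm1)            (output gate)
    #   h = o * activation(c)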
2217
2218    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
2219    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
2220        h_tm1, training, count=4)
2221
2222    if self.implementation == 1:
2223      if 0 < self.dropout < 1.:
2224        inputs_i = inputs * dp_mask[0]
2225        inputs_f = inputs * dp_mask[1]
2226        inputs_c = inputs * dp_mask[2]
2227        inputs_o = inputs * dp_mask[3]
2228      else:
2229        inputs_i = inputs
2230        inputs_f = inputs
2231        inputs_c = inputs
2232        inputs_o = inputs
2233      k_i, k_f, k_c, k_o = array_ops.split(
2234          self.kernel, num_or_size_splits=4, axis=1)
2235      x_i = K.dot(inputs_i, k_i)
2236      x_f = K.dot(inputs_f, k_f)
2237      x_c = K.dot(inputs_c, k_c)
2238      x_o = K.dot(inputs_o, k_o)
2239      if self.use_bias:
2240        b_i, b_f, b_c, b_o = array_ops.split(
2241            self.bias, num_or_size_splits=4, axis=0)
2242        x_i = K.bias_add(x_i, b_i)
2243        x_f = K.bias_add(x_f, b_f)
2244        x_c = K.bias_add(x_c, b_c)
2245        x_o = K.bias_add(x_o, b_o)
2246
2247      if 0 < self.recurrent_dropout < 1.:
2248        h_tm1_i = h_tm1 * rec_dp_mask[0]
2249        h_tm1_f = h_tm1 * rec_dp_mask[1]
2250        h_tm1_c = h_tm1 * rec_dp_mask[2]
2251        h_tm1_o = h_tm1 * rec_dp_mask[3]
2252      else:
2253        h_tm1_i = h_tm1
2254        h_tm1_f = h_tm1
2255        h_tm1_c = h_tm1
2256        h_tm1_o = h_tm1
2257      x = (x_i, x_f, x_c, x_o)
2258      h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
2259      c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
2260    else:
2261      if 0. < self.dropout < 1.:
2262        inputs *= dp_mask[0]
2263      z = K.dot(inputs, self.kernel)
2264      if 0. < self.recurrent_dropout < 1.:
2265        h_tm1 *= rec_dp_mask[0]
2266      z += K.dot(h_tm1, self.recurrent_kernel)
2267      if self.use_bias:
2268        z = K.bias_add(z, self.bias)
2269
2270      z = array_ops.split(z, num_or_size_splits=4, axis=1)
2271      c, o = self._compute_carry_and_output_fused(z, c_tm1)
2272
2273    h = o * self.activation(c)
2274    return h, [h, c]
2275
2276  def get_config(self):
2277    config = {
2278        'units':
2279            self.units,
2280        'activation':
2281            activations.serialize(self.activation),
2282        'recurrent_activation':
2283            activations.serialize(self.recurrent_activation),
2284        'use_bias':
2285            self.use_bias,
2286        'kernel_initializer':
2287            initializers.serialize(self.kernel_initializer),
2288        'recurrent_initializer':
2289            initializers.serialize(self.recurrent_initializer),
2290        'bias_initializer':
2291            initializers.serialize(self.bias_initializer),
2292        'unit_forget_bias':
2293            self.unit_forget_bias,
2294        'kernel_regularizer':
2295            regularizers.serialize(self.kernel_regularizer),
2296        'recurrent_regularizer':
2297            regularizers.serialize(self.recurrent_regularizer),
2298        'bias_regularizer':
2299            regularizers.serialize(self.bias_regularizer),
2300        'kernel_constraint':
2301            constraints.serialize(self.kernel_constraint),
2302        'recurrent_constraint':
2303            constraints.serialize(self.recurrent_constraint),
2304        'bias_constraint':
2305            constraints.serialize(self.bias_constraint),
2306        'dropout':
2307            self.dropout,
2308        'recurrent_dropout':
2309            self.recurrent_dropout,
2310        'implementation':
2311            self.implementation
2312    }
2313    base_config = super(LSTMCell, self).get_config()
2314    return dict(list(base_config.items()) + list(config.items()))
2315
2316  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
2317    return list(_generate_zero_filled_state_for_cell(
2318        self, inputs, batch_size, dtype))
2319
2320
2321@keras_export('keras.experimental.PeepholeLSTMCell')
2322class PeepholeLSTMCell(LSTMCell):
2323  """Equivalent to LSTMCell class but adds peephole connections.
2324
2325  Peephole connections allow the gates to utilize the previous internal state
2326  as well as the previous hidden state (which is all that LSTMCell has access
2327  to), letting PeepholeLSTMCell learn precise timings more easily than LSTMCell.
2328
2329  From [Gers et al.](http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):
2330
2331  "We find that LSTM augmented by 'peephole connections' from its internal
2332  cells to its multiplicative gates can learn the fine distinction between
2333  sequences of spikes spaced either 50 or 49 time steps apart without the help
2334  of any short training exemplars."
2335
2336  The peephole implementation is based on:
2337
2338  [Long short-term memory recurrent neural network architectures for
2339   large scale acoustic modeling.
2340  ](https://research.google.com/pubs/archive/43905.pdf)
2341
2342  Example:
2343
2344  ```python
2345  # Create 2 PeepholeLSTMCells
2346  peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
2347  # Create a layer composed sequentially of the peephole LSTM cells.
2348  layer = RNN(peephole_lstm_cells)
2349  inputs = keras.Input((timesteps, input_dim))
2350  output = layer(inputs)
2351  ```
2352  """
2353
2354  def build(self, input_shape):
2355    super(PeepholeLSTMCell, self).build(input_shape)
2356    # The following are the weight matrices for the peephole connections. These
2357    # are multiplied with the previous internal state during the computation of
2358    # carry and output.
2359    self.input_gate_peephole_weights = self.add_weight(
2360        shape=(self.units,),
2361        name='input_gate_peephole_weights',
2362        initializer=self.kernel_initializer)
2363    self.forget_gate_peephole_weights = self.add_weight(
2364        shape=(self.units,),
2365        name='forget_gate_peephole_weights',
2366        initializer=self.kernel_initializer)
2367    self.output_gate_peephole_weights = self.add_weight(
2368        shape=(self.units,),
2369        name='output_gate_peephole_weights',
2370        initializer=self.kernel_initializer)
2371
2372  def _compute_carry_and_output(self, x, h_tm1, c_tm1):
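    # Same as LSTMCell, except the input and forget gates also see the previous
    # carry state, and the output gate sees the new carry state, each through
    # elementwise peephole weights.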
2373    x_i, x_f, x_c, x_o = x
2374    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
2375    i = self.recurrent_activation(
2376        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
2377        self.input_gate_peephole_weights * c_tm1)
2378    f = self.recurrent_activation(x_f + K.dot(
2379        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
2380                                  self.forget_gate_peephole_weights * c_tm1)
2381    c = f * c_tm1 + i * self.activation(x_c + K.dot(
2382        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
2383    o = self.recurrent_activation(
2384        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
2385        self.output_gate_peephole_weights * c)
2386    return c, o
2387
2388  def _compute_carry_and_output_fused(self, z, c_tm1):
2389    z0, z1, z2, z3 = z
2390    i = self.recurrent_activation(z0 +
2391                                  self.input_gate_peephole_weights * c_tm1)
2392    f = self.recurrent_activation(z1 +
2393                                  self.forget_gate_peephole_weights * c_tm1)
2394    c = f * c_tm1 + i * self.activation(z2)
2395    o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c)
2396    return c, o
2397
2398
2399@keras_export(v1=['keras.layers.LSTM'])
2400class LSTM(RNN):
2401  """Long Short-Term Memory layer - Hochreiter 1997.
2402
2403  Note that this layer is not optimized for performance on GPU. Please use
2404  `tf.keras.layers.CuDNNLSTM` for better performance on GPU.
2405
2406  Arguments:
2407    units: Positive integer, dimensionality of the output space.
2408    activation: Activation function to use.
2409      Default: hyperbolic tangent (`tanh`).
2410      If you pass `None`, no activation is applied
2411      (i.e. "linear" activation: `a(x) = x`).
2412    recurrent_activation: Activation function to use
2413      for the recurrent step.
2414      Default: hard sigmoid (`hard_sigmoid`).
2415      If you pass `None`, no activation is applied
2416      (i.e. "linear" activation: `a(x) = x`).
2417    use_bias: Boolean, whether the layer uses a bias vector.
2418    kernel_initializer: Initializer for the `kernel` weights matrix,
2419      used for the linear transformation of the inputs.
2420    recurrent_initializer: Initializer for the `recurrent_kernel`
2421      weights matrix,
2422      used for the linear transformation of the recurrent state.
2423    bias_initializer: Initializer for the bias vector.
2424    unit_forget_bias: Boolean.
2425      If True, add 1 to the bias of the forget gate at initialization.
2426      Setting it to true will also force `bias_initializer="zeros"`.
2427      This is recommended in [Jozefowicz et
2428        al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
2429    kernel_regularizer: Regularizer function applied to
2430      the `kernel` weights matrix.
2431    recurrent_regularizer: Regularizer function applied to
2432      the `recurrent_kernel` weights matrix.
2433    bias_regularizer: Regularizer function applied to the bias vector.
2434    activity_regularizer: Regularizer function applied to
2435      the output of the layer (its "activation").
2436    kernel_constraint: Constraint function applied to
2437      the `kernel` weights matrix.
2438    recurrent_constraint: Constraint function applied to
2439      the `recurrent_kernel` weights matrix.
2440    bias_constraint: Constraint function applied to the bias vector.
2441    dropout: Float between 0 and 1.
2442      Fraction of the units to drop for
2443      the linear transformation of the inputs.
2444    recurrent_dropout: Float between 0 and 1.
2445      Fraction of the units to drop for
2446      the linear transformation of the recurrent state.
2447    implementation: Implementation mode, either 1 or 2.
2448      Mode 1 will structure its operations as a larger number of
2449      smaller dot products and additions, whereas mode 2 will
2450      batch them into fewer, larger operations. These modes will
2451      have different performance profiles on different hardware and
2452      for different applications.
2453    return_sequences: Boolean. Whether to return the last output
2454      in the output sequence, or the full sequence.
2455    return_state: Boolean. Whether to return the last state
2456      in addition to the output.
2457    go_backwards: Boolean (default False).
2458      If True, process the input sequence backwards and return the
2459      reversed sequence.
2460    stateful: Boolean (default False). If True, the last state
2461      for each sample at index i in a batch will be used as initial
2462      state for the sample of index i in the following batch.
2463    unroll: Boolean (default False).
2464      If True, the network will be unrolled,
2465      else a symbolic loop will be used.
2466      Unrolling can speed up an RNN,
2467      although it tends to be more memory-intensive.
2468      Unrolling is only suitable for short sequences.
2469
2470  Call arguments:
2471    inputs: A 3D tensor.
2472    mask: Binary tensor of shape `(samples, timesteps)` indicating whether
2473      a given timestep should be masked.
2474    training: Python boolean indicating whether the layer should behave in
2475      training mode or in inference mode. This argument is passed to the cell
2476      when calling it. This is only relevant if `dropout` or
2477      `recurrent_dropout` is used.
2478    initial_state: List of initial state tensors to be passed to the first
2479      call of the cell.
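
  Example (a minimal usage sketch; shapes are illustrative):

  ```python
  import numpy as np
  import tensorflow as tf

  inputs = np.random.random([32, 10, 8]).astype(np.float32)

  lstm = tf.keras.layers.LSTM(4)
  output = lstm(inputs)  # output shape: (32, 4)

  lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
  # whole_sequence_output shape: (32, 10, 4); each final state shape: (32, 4)
  whole_sequence_output, final_memory_state, final_carry_state = lstm(inputs)
  ```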
2480  """
2481
2482  def __init__(self,
2483               units,
2484               activation='tanh',
2485               recurrent_activation='hard_sigmoid',
2486               use_bias=True,
2487               kernel_initializer='glorot_uniform',
2488               recurrent_initializer='orthogonal',
2489               bias_initializer='zeros',
2490               unit_forget_bias=True,
2491               kernel_regularizer=None,
2492               recurrent_regularizer=None,
2493               bias_regularizer=None,
2494               activity_regularizer=None,
2495               kernel_constraint=None,
2496               recurrent_constraint=None,
2497               bias_constraint=None,
2498               dropout=0.,
2499               recurrent_dropout=0.,
2500               implementation=1,
2501               return_sequences=False,
2502               return_state=False,
2503               go_backwards=False,
2504               stateful=False,
2505               unroll=False,
2506               **kwargs):
2507    if implementation == 0:
2508      logging.warning('`implementation=0` has been deprecated, '
2509                      'and now defaults to `implementation=1`. '
2510                      'Please update your layer call.')
2511    cell = LSTMCell(
2512        units,
2513        activation=activation,
2514        recurrent_activation=recurrent_activation,
2515        use_bias=use_bias,
2516        kernel_initializer=kernel_initializer,
2517        recurrent_initializer=recurrent_initializer,
2518        unit_forget_bias=unit_forget_bias,
2519        bias_initializer=bias_initializer,
2520        kernel_regularizer=kernel_regularizer,
2521        recurrent_regularizer=recurrent_regularizer,
2522        bias_regularizer=bias_regularizer,
2523        kernel_constraint=kernel_constraint,
2524        recurrent_constraint=recurrent_constraint,
2525        bias_constraint=bias_constraint,
2526        dropout=dropout,
2527        recurrent_dropout=recurrent_dropout,
2528        implementation=implementation)
2529    super(LSTM, self).__init__(
2530        cell,
2531        return_sequences=return_sequences,
2532        return_state=return_state,
2533        go_backwards=go_backwards,
2534        stateful=stateful,
2535        unroll=unroll,
2536        **kwargs)
2537    self.activity_regularizer = regularizers.get(activity_regularizer)
2538    self.input_spec = [InputSpec(ndim=3)]
2539
2540  def call(self, inputs, mask=None, training=None, initial_state=None):
2541    self.cell.reset_dropout_mask()
2542    self.cell.reset_recurrent_dropout_mask()
2543    return super(LSTM, self).call(
2544        inputs, mask=mask, training=training, initial_state=initial_state)
2545
2546  @property
2547  def units(self):
2548    return self.cell.units
2549
2550  @property
2551  def activation(self):
2552    return self.cell.activation
2553
2554  @property
2555  def recurrent_activation(self):
2556    return self.cell.recurrent_activation
2557
2558  @property
2559  def use_bias(self):
2560    return self.cell.use_bias
2561
2562  @property
2563  def kernel_initializer(self):
2564    return self.cell.kernel_initializer
2565
2566  @property
2567  def recurrent_initializer(self):
2568    return self.cell.recurrent_initializer
2569
2570  @property
2571  def bias_initializer(self):
2572    return self.cell.bias_initializer
2573
2574  @property
2575  def unit_forget_bias(self):
2576    return self.cell.unit_forget_bias
2577
2578  @property
2579  def kernel_regularizer(self):
2580    return self.cell.kernel_regularizer
2581
2582  @property
2583  def recurrent_regularizer(self):
2584    return self.cell.recurrent_regularizer
2585
2586  @property
2587  def bias_regularizer(self):
2588    return self.cell.bias_regularizer
2589
2590  @property
2591  def kernel_constraint(self):
2592    return self.cell.kernel_constraint
2593
2594  @property
2595  def recurrent_constraint(self):
2596    return self.cell.recurrent_constraint
2597
2598  @property
2599  def bias_constraint(self):
2600    return self.cell.bias_constraint
2601
2602  @property
2603  def dropout(self):
2604    return self.cell.dropout
2605
2606  @property
2607  def recurrent_dropout(self):
2608    return self.cell.recurrent_dropout
2609
2610  @property
2611  def implementation(self):
2612    return self.cell.implementation
2613
  def get_config(self):
    config = {
        'units':
            self.units,
        'activation':
            activations.serialize(self.activation),
        'recurrent_activation':
            activations.serialize(self.recurrent_activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'recurrent_initializer':
            initializers.serialize(self.recurrent_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'unit_forget_bias':
            self.unit_forget_bias,
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'recurrent_regularizer':
            regularizers.serialize(self.recurrent_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'recurrent_constraint':
            constraints.serialize(self.recurrent_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'dropout':
            self.dropout,
        'recurrent_dropout':
            self.recurrent_dropout,
        'implementation':
            self.implementation
    }
    base_config = super(LSTM, self).get_config()
    del base_config['cell']
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    if 'implementation' in config and config['implementation'] == 0:
      config['implementation'] = 1
    return cls(**config)

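
# --- Illustrative sketch (not part of the original module) -------------------
# A minimal example of how the `get_config()` / `from_config()` pair above can
# round-trip an `LSTM` layer, including the upgrade of the legacy
# `implementation=0` value to 1. The function name is hypothetical and exists
# only for illustration.
def _lstm_config_roundtrip_sketch():
  layer = LSTM(4, return_sequences=True)
  config = layer.get_config()          # Serializes all cell hyperparameters.
  config['implementation'] = 0         # Legacy value found in old configs.
  restored = LSTM.from_config(config)  # Rebuilds the layer from the config.
  assert restored.implementation == 1  # `from_config` upgraded the value.
  return restored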

def _generate_dropout_mask(ones, rate, training=None, count=1):
  """Builds `count` (inverted) dropout mask(s) shaped like `ones`.

  In the training phase each mask randomly zeroes entries and scales the
  survivors by `1 / (1 - rate)`; outside training the all-ones tensor is
  returned unchanged.
  """
  def dropped_inputs():
    return K.dropout(ones, rate)

  if count > 1:
    return [
        K.in_train_phase(dropped_inputs, ones, training=training)
        for _ in range(count)
    ]
  return K.in_train_phase(dropped_inputs, ones, training=training)

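
# --- Illustrative sketch (not part of the original module) -------------------
# `_generate_dropout_mask` builds (inverted) dropout masks from an all-ones
# tensor: during training, entries are randomly zeroed and the survivors are
# scaled by 1 / (1 - rate); outside training the ones tensor is returned
# unchanged. The helper name below is hypothetical and for illustration only.
def _dropout_mask_sketch():
  ones = array_ops.ones((2, 3))
  # Request 4 independent masks (e.g. one per gate of an LSTM-style cell).
  masks = _generate_dropout_mask(ones, rate=0.5, training=True, count=4)
  return masks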

def _standardize_args(
    inputs, initial_state, constants, num_constants, num_inputs=1):
  """Standardizes `__call__` to a single list of tensor inputs.

  When running a model loaded from a file, the input tensors
  `initial_state` and `constants` can be passed to `RNN.__call__()` as part
  of `inputs` instead of by the dedicated keyword arguments. This method
  makes sure the arguments are separated and that `initial_state` and
  `constants` are lists of tensors (or None).

  Arguments:
    inputs: Tensor or list/tuple of tensors, which may include constants
      and initial states. In that case `num_constants` must be specified.
    initial_state: Tensor or list of tensors or None, initial states.
    constants: Tensor or list of tensors or None, constant tensors.
    num_constants: Expected number of constants (if constants are passed as
      part of the `inputs` list).
    num_inputs: Expected number of real input tensors (excluding initial
      states and constants).

  Returns:
    inputs: Single tensor or tuple of tensors.
    initial_state: List of tensors or None.
    constants: List of tensors or None.
  """
  if isinstance(inputs, list):
    # There are several situations here:
    # In graph mode, __call__ will only be called once. The initial_state
    # and constants could be in inputs (from file loading).
    # In eager mode, __call__ will be called twice: once during
    # rnn_layer(inputs=input_t, constants=c_t, ...), and a second time during
    # model.fit/train_on_batch/predict with real np data. In the second case,
    # the inputs will contain initial_state and constants, and more
    # importantly, the real inputs will be in a flat list instead of a nested
    # tuple.
    #
    # For either case, we will use num_inputs to split the input list, and
    # restructure the real input into a tuple.
    assert initial_state is None and constants is None
    inputs = nest.flatten(inputs)
    if num_constants is not None:
      constants = inputs[-num_constants:]
      inputs = inputs[:-num_constants]
    if num_inputs is None:
      num_inputs = 1
    if len(inputs) > num_inputs:
      initial_state = inputs[num_inputs:]
      inputs = inputs[:num_inputs]

    if len(inputs) > 1:
      inputs = tuple(inputs)
    else:
      inputs = inputs[0]

  def to_list_or_none(x):
    if x is None or isinstance(x, list):
      return x
    if isinstance(x, tuple):
      return list(x)
    return [x]

  initial_state = to_list_or_none(initial_state)
  constants = to_list_or_none(constants)

  return inputs, initial_state, constants

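
# --- Illustrative sketch (not part of the original module) -------------------
# Example of how `_standardize_args` splits a flat `inputs` list back into the
# real input, the initial states and the constants. The tensor values and the
# function name are hypothetical and for illustration only.
def _standardize_args_sketch():
  x = array_ops.ones((8, 5, 4))   # One real input: (batch, time, features).
  s1 = array_ops.ones((8, 3))     # Two initial state tensors ...
  s2 = array_ops.ones((8, 3))
  c = array_ops.ones((8, 2))      # ... and one constant tensor.
  inputs, initial_state, constants = _standardize_args(
      [x, s1, s2, c], None, None, num_constants=1, num_inputs=1)
  # inputs is `x`, initial_state is `[s1, s2]`, constants is `[c]`.
  return inputs, initial_state, constants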

def _is_multiple_state(state_size):
  """Check whether the state_size contains multiple states."""
  return (hasattr(state_size, '__len__') and
          not isinstance(state_size, tensor_shape.TensorShape))

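
# Illustrative examples (not part of the original module) of what
# `_is_multiple_state` reports for typical `state_size` values:
#
#   _is_multiple_state([64, 64])                        # -> True  (e.g. LSTM)
#   _is_multiple_state(64)                              # -> False
#   _is_multiple_state(tensor_shape.TensorShape([64]))  # -> False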

def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
  if inputs is not None:
    batch_size = array_ops.shape(inputs)[0]
    dtype = inputs.dtype
  return _generate_zero_filled_state(batch_size, cell.state_size, dtype)


def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
  """Generate a zero filled tensor with shape [batch_size, state_size]."""
  if batch_size_tensor is None or dtype is None:
    raise ValueError(
        'batch_size and dtype cannot be None while constructing initial state: '
        'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))

  def create_zeros(unnested_state_size):
    flat_dims = tensor_shape.as_shape(unnested_state_size).as_list()
    init_state_size = [batch_size_tensor] + flat_dims
    return array_ops.zeros(init_state_size, dtype=dtype)

  if nest.is_sequence(state_size):
    return nest.map_structure(create_zeros, state_size)
  else:
    return create_zeros(state_size)

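
# --- Illustrative sketch (not part of the original module) -------------------
# For a nested `state_size` the zero-filled state mirrors the nesting: an
# LSTM-style `state_size` of `[units, units]` yields two `[batch, units]` zero
# tensors. The function name is hypothetical and for illustration only.
def _zero_state_sketch():
  states = _generate_zero_filled_state(
      batch_size_tensor=8, state_size=[32, 32], dtype='float32')
  # `states` is a list of two zero tensors, each of shape (8, 32).
  return states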