1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module implementing RNN Cells.
16
17This module provides a number of basic commonly used RNN cells, such as LSTM
18(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
19operators that allow adding dropouts, projections, or embeddings for inputs.
20Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
21calling the `rnn` ops several times.
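
As a minimal usage sketch (the tensor names and sizes below are illustrative,
and `tf` denotes the TensorFlow module):

```
cells = [LSTMCell(num_units=64) for _ in range(2)]
stacked_cell = MultiRNNCell(cells)
state = stacked_cell.zero_state(batch_size=32, dtype=tf.float32)
# inputs: a [32, input_depth] tensor for a single time step.
output, state = stacked_cell(inputs, state)
```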
22"""
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import collections
28import hashlib
29import numbers
30
31from tensorflow.python.eager import context
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.keras import activations
38from tensorflow.python.keras import initializers
39from tensorflow.python.keras import layers as keras_layer
40from tensorflow.python.keras.engine import input_spec
41from tensorflow.python.keras.utils import tf_utils
42from tensorflow.python.layers import base as base_layer
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import clip_ops
45from tensorflow.python.ops import init_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import nn_ops
48from tensorflow.python.ops import partitioned_variables
49from tensorflow.python.ops import random_ops
50from tensorflow.python.ops import tensor_array_ops
51from tensorflow.python.ops import variable_scope as vs
52from tensorflow.python.ops import variables as tf_variables
53from tensorflow.python.platform import tf_logging as logging
54from tensorflow.python.training.tracking import base as trackable
55from tensorflow.python.util import nest
56from tensorflow.python.util.deprecation import deprecated
57from tensorflow.python.util.tf_export import tf_export
58
59
60_BIAS_VARIABLE_NAME = "bias"
61_WEIGHTS_VARIABLE_NAME = "kernel"
62
63# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
64ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
65
66
67def _hasattr(obj, attr_name):
68  try:
69    getattr(obj, attr_name)
70  except AttributeError:
71    return False
72  else:
73    return True
74
75
76def assert_like_rnncell(cell_name, cell):
77  """Raises a TypeError if cell is not like an RNNCell.
78
  NOTE: Do not rely on the error message (in particular in tests), as it may
  change to improve readability. Use
  ASSERT_LIKE_RNNCELL_ERROR_REGEXP.
82
83  Args:
    cell_name: A string used to give a meaningful error message referencing
      the name of the function argument.
86    cell: The object which should behave like an RNNCell.
87
88  Raises:
89    TypeError: A human-friendly exception.
90  """
91  conditions = [
92      _hasattr(cell, "output_size"),
93      _hasattr(cell, "state_size"),
94      _hasattr(cell, "get_initial_state") or _hasattr(cell, "zero_state"),
95      callable(cell),
96  ]
97  errors = [
98      "'output_size' property is missing",
99      "'state_size' property is missing",
100      "either 'zero_state' or 'get_initial_state' method is required",
101      "is not callable"
102  ]
103
104  if not all(conditions):
105
106    errors = [error for error, cond in zip(errors, conditions) if not cond]
107    raise TypeError("The argument {!r} ({}) is not an RNNCell: {}.".format(
108        cell_name, cell, ", ".join(errors)))
109
110
111def _concat(prefix, suffix, static=False):
112  """Concat that enables int, Tensor, or TensorShape values.
113
114  This function takes a size specification, which can be an integer, a
115  TensorShape, or a Tensor, and converts it into a concatenated Tensor
116  (if static = False) or a list of integers (if static = True).
117
118  Args:
119    prefix: The prefix; usually the batch size (and/or time step size).
120      (TensorShape, int, or Tensor.)
121    suffix: TensorShape, int, or Tensor.
122    static: If `True`, return a python list with possibly unknown dimensions.
123      Otherwise return a `Tensor`.
124
125  Returns:
126    shape: the concatenation of prefix and suffix.
127
128  Raises:
129    ValueError: if `suffix` is not a scalar or vector (or TensorShape).
    ValueError: if `prefix` or `suffix` is `None` and dynamic `Tensor` output
      was requested (`static=False`).
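
  For example (an illustrative sketch):

  ```
  _concat(32, 7, static=True)   # -> [32, 7]
  _concat(32, 7, static=False)  # -> a 1-D int32 Tensor holding [32, 7]
  ```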
132  """
133  if isinstance(prefix, ops.Tensor):
134    p = prefix
135    p_static = tensor_util.constant_value(prefix)
136    if p.shape.ndims == 0:
137      p = array_ops.expand_dims(p, 0)
138    elif p.shape.ndims != 1:
139      raise ValueError("prefix tensor must be either a scalar or vector, "
140                       "but saw tensor: %s" % p)
141  else:
142    p = tensor_shape.as_shape(prefix)
143    p_static = p.as_list() if p.ndims is not None else None
144    p = (constant_op.constant(p.as_list(), dtype=dtypes.int32)
145         if p.is_fully_defined() else None)
146  if isinstance(suffix, ops.Tensor):
147    s = suffix
148    s_static = tensor_util.constant_value(suffix)
149    if s.shape.ndims == 0:
150      s = array_ops.expand_dims(s, 0)
151    elif s.shape.ndims != 1:
152      raise ValueError("suffix tensor must be either a scalar or vector, "
153                       "but saw tensor: %s" % s)
154  else:
155    s = tensor_shape.as_shape(suffix)
156    s_static = s.as_list() if s.ndims is not None else None
157    s = (constant_op.constant(s.as_list(), dtype=dtypes.int32)
158         if s.is_fully_defined() else None)
159
160  if static:
161    shape = tensor_shape.as_shape(p_static).concatenate(s_static)
162    shape = shape.as_list() if shape.ndims is not None else None
163  else:
164    if p is None or s is None:
165      raise ValueError("Provided a prefix or suffix of None: %s and %s"
166                       % (prefix, suffix))
167    shape = array_ops.concat((p, s), 0)
168  return shape
169
170
171def _zero_state_tensors(state_size, batch_size, dtype):
172  """Create tensors of zeros based on state_size, batch_size, and dtype."""
173  def get_state_shape(s):
174    """Combine s with batch_size to get a proper tensor shape."""
175    c = _concat(batch_size, s)
176    size = array_ops.zeros(c, dtype=dtype)
177    if not context.executing_eagerly():
178      c_static = _concat(batch_size, s, static=True)
179      size.set_shape(c_static)
180    return size
181  return nest.map_structure(get_state_shape, state_size)
182
183
184@tf_export(v1=["nn.rnn_cell.RNNCell"])
185class RNNCell(base_layer.Layer):
186  """Abstract object representing an RNN cell.
187
188  Every `RNNCell` must have the properties below and implement `call` with
189  the signature `(output, next_state) = call(input, state)`.  The optional
190  third input argument, `scope`, is allowed for backwards compatibility
  purposes, but should be left off for new subclasses.
192
193  This definition of cell differs from the definition used in the literature.
194  In the literature, 'cell' refers to an object with a single scalar output.
195  This definition refers to a horizontal array of such units.
196
197  An RNN cell, in the most abstract setting, is anything that has
198  a state and performs some operation that takes a matrix of inputs.
199  This operation results in an output matrix with `self.output_size` columns.
200  If `self.state_size` is an integer, this operation also results in a new
201  state matrix with `self.state_size` columns.  If `self.state_size` is a
202  (possibly nested tuple of) TensorShape object(s), then it should return a
203  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
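
  As an illustrative sketch (not part of this module's API), a cell satisfying
  this contract could look like:

  ```
  class MyIdentityCell(RNNCell):
    # A toy cell that passes its state through unchanged.

    def __init__(self, num_units):
      super(MyIdentityCell, self).__init__()
      self._num_units = num_units

    @property
    def state_size(self):
      return self._num_units

    @property
    def output_size(self):
      return self._num_units

    def call(self, inputs, state):
      # A real cell would mix `inputs` and `state` here.
      return state, state
  ```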
205  """
206
207  def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
208    super(RNNCell, self).__init__(
209        trainable=trainable, name=name, dtype=dtype, **kwargs)
    # Attribute that indicates whether the cell is a TF RNN cell, due to the
    # slight difference between TF and Keras RNN cells. Notably, a TF cell does
    # not wrap a single-tensor state in a list, whereas a Keras cell wraps the
    # state in a list, so its call() has to unwrap it.
214    self._is_tf_rnn_cell = True
215
216  def __call__(self, inputs, state, scope=None):
217    """Run this RNN cell on inputs, starting from the given state.
218
219    Args:
220      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
221      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
222        with shape `[batch_size, self.state_size]`.  Otherwise, if
223        `self.state_size` is a tuple of integers, this should be a tuple
224        with shapes `[batch_size, s] for s in self.state_size`.
225      scope: VariableScope for the created subgraph; defaults to class name.
226
227    Returns:
228      A pair containing:
229
230      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
231      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
232        the arity and shapes of `state`.
233    """
234    if scope is not None:
235      with vs.variable_scope(scope,
236                             custom_getter=self._rnn_get_variable) as scope:
237        return super(RNNCell, self).__call__(inputs, state, scope=scope)
238    else:
239      scope_attrname = "rnncell_scope"
240      scope = getattr(self, scope_attrname, None)
241      if scope is None:
242        scope = vs.variable_scope(vs.get_variable_scope(),
243                                  custom_getter=self._rnn_get_variable)
244        setattr(self, scope_attrname, scope)
245      with scope:
246        return super(RNNCell, self).__call__(inputs, state)
247
248  def _rnn_get_variable(self, getter, *args, **kwargs):
249    variable = getter(*args, **kwargs)
250    if context.executing_eagerly():
251      trainable = variable._trainable  # pylint: disable=protected-access
252    else:
253      trainable = (
254          variable in tf_variables.trainable_variables() or
255          (isinstance(variable, tf_variables.PartitionedVariable) and
256           list(variable)[0] in tf_variables.trainable_variables()))
257    if trainable and variable not in self._trainable_weights:
258      self._trainable_weights.append(variable)
259    elif not trainable and variable not in self._non_trainable_weights:
260      self._non_trainable_weights.append(variable)
261    return variable
262
263  @property
264  def state_size(self):
265    """size(s) of state(s) used by this cell.
266
267    It can be represented by an Integer, a TensorShape or a tuple of Integers
268    or TensorShapes.
269    """
270    raise NotImplementedError("Abstract method")
271
272  @property
273  def output_size(self):
274    """Integer or TensorShape: size of outputs produced by this cell."""
275    raise NotImplementedError("Abstract method")
276
277  def build(self, _):
278    # This tells the parent Layer object that it's OK to call
279    # self.add_variable() inside the call() method.
280    pass
281
282  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
283    if inputs is not None:
284      # Validate the given batch_size and dtype against inputs if provided.
285      inputs = ops.convert_to_tensor(inputs, name="inputs")
286      if batch_size is not None:
287        if tensor_util.is_tensor(batch_size):
288          static_batch_size = tensor_util.constant_value(
289              batch_size, partial=True)
290        else:
291          static_batch_size = batch_size
292        if inputs.shape.dims[0].value != static_batch_size:
293          raise ValueError(
294              "batch size from input tensor is different from the "
295              "input param. Input tensor batch: {}, batch_size: {}".format(
296                  inputs.shape.dims[0].value, batch_size))
297
298      if dtype is not None and inputs.dtype != dtype:
299        raise ValueError(
300            "dtype from input tensor is different from the "
301            "input param. Input tensor dtype: {}, dtype: {}".format(
302                inputs.dtype, dtype))
303
304      batch_size = inputs.shape.dims[0].value or array_ops.shape(inputs)[0]
305      dtype = inputs.dtype
306    if None in [batch_size, dtype]:
307      raise ValueError(
308          "batch_size and dtype cannot be None while constructing initial "
309          "state: batch_size={}, dtype={}".format(batch_size, dtype))
310    return self.zero_state(batch_size, dtype)
311
312  def zero_state(self, batch_size, dtype):
313    """Return zero-filled state tensor(s).
314
315    Args:
316      batch_size: int, float, or unit Tensor representing the batch size.
317      dtype: the data type to use for the state.
318
319    Returns:
      If `state_size` is an int or TensorShape, then the return value is an
      `N-D` tensor of shape `[batch_size, state_size]` filled with zeros.
322
323      If `state_size` is a nested list or tuple, then the return value is
324      a nested list or tuple (of the same structure) of `2-D` tensors with
325      the shapes `[batch_size, s]` for each s in `state_size`.
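
    For example (illustrative), for a cell whose `state_size` is the integer
    `10` (`tf` denotes the TensorFlow module):

    ```
    state = cell.zero_state(batch_size=4, dtype=tf.float32)
    # state is a [4, 10] tensor of zeros.
    ```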
326    """
327    # Try to use the last cached zero_state. This is done to avoid recreating
328    # zeros, especially when eager execution is enabled.
329    state_size = self.state_size
330    is_eager = context.executing_eagerly()
331    if is_eager and _hasattr(self, "_last_zero_state"):
332      (last_state_size, last_batch_size, last_dtype,
333       last_output) = getattr(self, "_last_zero_state")
334      if (last_batch_size == batch_size and
335          last_dtype == dtype and
336          last_state_size == state_size):
337        return last_output
338    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
339      output = _zero_state_tensors(state_size, batch_size, dtype)
340    if is_eager:
341      self._last_zero_state = (state_size, batch_size, dtype, output)
342    return output
343
344
345class LayerRNNCell(RNNCell):
346  """Subclass of RNNCells that act like proper `tf.Layer` objects.
347
  For backwards compatibility purposes, most `RNNCell` instances allow their
  `call` methods to instantiate variables via `tf.get_variable`.  The underlying
  variable scope thus keeps track of any variables and returns cached
  versions.  This is atypical of `tf.layer` objects, which separate this
  part of layer building into a `build` method that is only called once.

  Here we provide a subclass for `RNNCell` objects that act exactly as
  `Layer` objects do.  They must provide a `build` method and their
  `call` methods do not access Variables via `tf.get_variable`.
357  """
358
359  def __call__(self, inputs, state, scope=None, *args, **kwargs):
360    """Run this RNN cell on inputs, starting from the given state.
361
362    Args:
363      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
364      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
365        with shape `[batch_size, self.state_size]`.  Otherwise, if
366        `self.state_size` is a tuple of integers, this should be a tuple
367        with shapes `[batch_size, s] for s in self.state_size`.
368      scope: optional cell scope.
369      *args: Additional positional arguments.
370      **kwargs: Additional keyword arguments.
371
372    Returns:
373      A pair containing:
374
375      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
376      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
377        the arity and shapes of `state`.
378    """
379    # Bypass RNNCell's variable capturing semantics for LayerRNNCell.
380    # Instead, it is up to subclasses to provide a proper build
381    # method.  See the class docstring for more details.
382    return base_layer.Layer.__call__(self, inputs, state, scope=scope,
383                                     *args, **kwargs)
384
385
386@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
387class BasicRNNCell(LayerRNNCell):
388  """The most basic RNN cell.
389
390  Note that this cell is not optimized for performance. Please use
391  `tf.contrib.cudnn_rnn.CudnnRNNTanh` for better performance on GPU.
392
393  Args:
394    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use.  Default: `tanh`. It can also be a string
      name of a Keras activation function.
397    reuse: (optional) Python boolean describing whether to reuse variables
398     in an existing scope.  If not `True`, and the existing scope already has
399     the given variables, an error is raised.
400    name: String, the name of the layer. Layers with the same name will
401      share weights, but to avoid mistakes we require reuse=True in such
402      cases.
403    dtype: Default dtype of the layer (default of `None` means use the type
404      of the first input). Required when `build` is called before `call`.
    **kwargs: Dict, keyword named properties for common layer attributes, like
      `trainable`, etc., when constructing the cell from the configs returned by
      `get_config()`.
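
  A minimal usage sketch (`tf` denotes the TensorFlow module; batch and
  feature sizes are illustrative):

  ```
  cell = BasicRNNCell(num_units=128)
  state = cell.zero_state(batch_size=32, dtype=tf.float32)
  # inputs: a [32, input_depth] tensor for a single time step.
  output, next_state = cell(inputs, state)
  ```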
407  """
408
  @deprecated(None, "This class is equivalent to tf.keras.layers.SimpleRNNCell,"
                    " and will be replaced by that in TensorFlow 2.0.")
411  def __init__(self,
412               num_units,
413               activation=None,
414               reuse=None,
415               name=None,
416               dtype=None,
417               **kwargs):
418    super(BasicRNNCell, self).__init__(
419        _reuse=reuse, name=name, dtype=dtype, **kwargs)
420    _check_supported_dtypes(self.dtype)
421    if context.executing_eagerly() and context.num_gpus() > 0:
422      logging.warn("%s: Note that this cell is not optimized for performance. "
423                   "Please use tf.contrib.cudnn_rnn.CudnnRNNTanh for better "
424                   "performance on GPU.", self)
425
426    # Inputs must be 2-dimensional.
427    self.input_spec = input_spec.InputSpec(ndim=2)
428
429    self._num_units = num_units
430    if activation:
431      self._activation = activations.get(activation)
432    else:
433      self._activation = math_ops.tanh
434
435  @property
436  def state_size(self):
437    return self._num_units
438
439  @property
440  def output_size(self):
441    return self._num_units
442
443  @tf_utils.shape_type_conversion
444  def build(self, inputs_shape):
445    if inputs_shape[-1] is None:
446      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
447                       % str(inputs_shape))
448    _check_supported_dtypes(self.dtype)
449
450    input_depth = inputs_shape[-1]
451    self._kernel = self.add_variable(
452        _WEIGHTS_VARIABLE_NAME,
453        shape=[input_depth + self._num_units, self._num_units])
454    self._bias = self.add_variable(
455        _BIAS_VARIABLE_NAME,
456        shape=[self._num_units],
457        initializer=init_ops.zeros_initializer(dtype=self.dtype))
458
459    self.built = True
460
461  def call(self, inputs, state):
462    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
463    _check_rnn_cell_input_dtypes([inputs, state])
464    gate_inputs = math_ops.matmul(
465        array_ops.concat([inputs, state], 1), self._kernel)
466    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
467    output = self._activation(gate_inputs)
468    return output, output
469
470  def get_config(self):
471    config = {
472        "num_units": self._num_units,
473        "activation": activations.serialize(self._activation),
474        "reuse": self._reuse,
475    }
476    base_config = super(BasicRNNCell, self).get_config()
477    return dict(list(base_config.items()) + list(config.items()))
478
479
480@tf_export(v1=["nn.rnn_cell.GRUCell"])
481class GRUCell(LayerRNNCell):
482  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
483
484  Note that this cell is not optimized for performance. Please use
485  `tf.contrib.cudnn_rnn.CudnnGRU` for better performance on GPU, or
486  `tf.contrib.rnn.GRUBlockCellV2` for better performance on CPU.
487
488  Args:
489    num_units: int, The number of units in the GRU cell.
490    activation: Nonlinearity to use.  Default: `tanh`.
491    reuse: (optional) Python boolean describing whether to reuse variables
492     in an existing scope.  If not `True`, and the existing scope already has
493     the given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight and
      projection matrices.
496    bias_initializer: (optional) The initializer to use for the bias.
497    name: String, the name of the layer. Layers with the same name will
498      share weights, but to avoid mistakes we require reuse=True in such
499      cases.
500    dtype: Default dtype of the layer (default of `None` means use the type
501      of the first input). Required when `build` is called before `call`.
    **kwargs: Dict, keyword named properties for common layer attributes, like
      `trainable`, etc., when constructing the cell from the configs returned by
      `get_config()`.
504  """
505
  @deprecated(None, "This class is equivalent to tf.keras.layers.GRUCell,"
                    " and will be replaced by that in TensorFlow 2.0.")
508  def __init__(self,
509               num_units,
510               activation=None,
511               reuse=None,
512               kernel_initializer=None,
513               bias_initializer=None,
514               name=None,
515               dtype=None,
516               **kwargs):
517    super(GRUCell, self).__init__(
518        _reuse=reuse, name=name, dtype=dtype, **kwargs)
519    _check_supported_dtypes(self.dtype)
520
521    if context.executing_eagerly() and context.num_gpus() > 0:
522      logging.warn("%s: Note that this cell is not optimized for performance. "
523                   "Please use tf.contrib.cudnn_rnn.CudnnGRU for better "
524                   "performance on GPU.", self)
525    # Inputs must be 2-dimensional.
526    self.input_spec = input_spec.InputSpec(ndim=2)
527
528    self._num_units = num_units
529    if activation:
530      self._activation = activations.get(activation)
531    else:
532      self._activation = math_ops.tanh
533    self._kernel_initializer = initializers.get(kernel_initializer)
534    self._bias_initializer = initializers.get(bias_initializer)
535
536  @property
537  def state_size(self):
538    return self._num_units
539
540  @property
541  def output_size(self):
542    return self._num_units
543
544  @tf_utils.shape_type_conversion
545  def build(self, inputs_shape):
546    if inputs_shape[-1] is None:
547      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
548                       % str(inputs_shape))
549    _check_supported_dtypes(self.dtype)
550    input_depth = inputs_shape[-1]
551    self._gate_kernel = self.add_variable(
552        "gates/%s" % _WEIGHTS_VARIABLE_NAME,
553        shape=[input_depth + self._num_units, 2 * self._num_units],
554        initializer=self._kernel_initializer)
555    self._gate_bias = self.add_variable(
556        "gates/%s" % _BIAS_VARIABLE_NAME,
557        shape=[2 * self._num_units],
558        initializer=(
559            self._bias_initializer
560            if self._bias_initializer is not None
561            else init_ops.constant_initializer(1.0, dtype=self.dtype)))
562    self._candidate_kernel = self.add_variable(
563        "candidate/%s" % _WEIGHTS_VARIABLE_NAME,
564        shape=[input_depth + self._num_units, self._num_units],
565        initializer=self._kernel_initializer)
566    self._candidate_bias = self.add_variable(
567        "candidate/%s" % _BIAS_VARIABLE_NAME,
568        shape=[self._num_units],
569        initializer=(
570            self._bias_initializer
571            if self._bias_initializer is not None
572            else init_ops.zeros_initializer(dtype=self.dtype)))
573
574    self.built = True
575
576  def call(self, inputs, state):
577    """Gated recurrent unit (GRU) with nunits cells."""
578    _check_rnn_cell_input_dtypes([inputs, state])
579
580    gate_inputs = math_ops.matmul(
581        array_ops.concat([inputs, state], 1), self._gate_kernel)
582    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
583
584    value = math_ops.sigmoid(gate_inputs)
585    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
586
587    r_state = r * state
588
589    candidate = math_ops.matmul(
590        array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
591    candidate = nn_ops.bias_add(candidate, self._candidate_bias)
592
593    c = self._activation(candidate)
594    new_h = u * state + (1 - u) * c
595    return new_h, new_h
596
597  def get_config(self):
598    config = {
599        "num_units": self._num_units,
600        "kernel_initializer": initializers.serialize(self._kernel_initializer),
601        "bias_initializer": initializers.serialize(self._bias_initializer),
602        "activation": activations.serialize(self._activation),
603        "reuse": self._reuse,
604    }
605    base_config = super(GRUCell, self).get_config()
606    return dict(list(base_config.items()) + list(config.items()))
607
608
609_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
610
611
612@tf_export(v1=["nn.rnn_cell.LSTMStateTuple"])
613class LSTMStateTuple(_LSTMStateTuple):
614  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
615
  Stores two elements, `(c, h)`, in that order, where `c` is the cell (memory)
  state and `h` is the output.
618
619  Only used when `state_is_tuple=True`.
620  """
621  __slots__ = ()
622
623  @property
624  def dtype(self):
625    (c, h) = self
626    if c.dtype != h.dtype:
627      raise TypeError("Inconsistent internal state: %s vs %s" %
628                      (str(c.dtype), str(h.dtype)))
629    return c.dtype
630
631
632@tf_export(v1=["nn.rnn_cell.BasicLSTMCell"])
633class BasicLSTMCell(LayerRNNCell):
634  """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead.
635
636  Basic LSTM recurrent network cell.
637
638  The implementation is based on: http://arxiv.org/abs/1409.2329.
639
  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting at the beginning of training.

  It does not allow cell clipping or a projection layer, and does not
  use peephole connections: it is the basic baseline.
645
646  For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell`
647  that follows.
648
649  Note that this cell is not optimized for performance. Please use
650  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
651  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
652  better performance on CPU.
653  """
654
  @deprecated(None, "This class is equivalent to tf.keras.layers.LSTMCell,"
                    " and will be replaced by that in TensorFlow 2.0.")
657  def __init__(self,
658               num_units,
659               forget_bias=1.0,
660               state_is_tuple=True,
661               activation=None,
662               reuse=None,
663               name=None,
664               dtype=None,
665               **kwargs):
666    """Initialize the basic LSTM cell.
667
668    Args:
669      num_units: int, The number of units in the LSTM cell.
670      forget_bias: float, The bias added to forget gates (see above).
        Must be set to `0.0` manually when restoring from CudnnLSTM-trained
672        checkpoints.
673      state_is_tuple: If True, accepted and returned states are 2-tuples of
674        the `c_state` and `m_state`.  If False, they are concatenated
675        along the column axis.  The latter behavior will soon be deprecated.
676      activation: Activation function of the inner states.  Default: `tanh`. It
        can also be a string name of a Keras activation function.
678      reuse: (optional) Python boolean describing whether to reuse variables
679        in an existing scope.  If not `True`, and the existing scope already has
680        the given variables, an error is raised.
681      name: String, the name of the layer. Layers with the same name will
682        share weights, but to avoid mistakes we require reuse=True in such
683        cases.
684      dtype: Default dtype of the layer (default of `None` means use the type
685        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes, like
        `trainable`, etc., when constructing the cell from the configs returned
        by `get_config()`.

      When restoring from CudnnLSTM-trained checkpoints, you must use
      `CudnnCompatibleLSTMCell` instead.
691    """
692    super(BasicLSTMCell, self).__init__(
693        _reuse=reuse, name=name, dtype=dtype, **kwargs)
694    _check_supported_dtypes(self.dtype)
695    if not state_is_tuple:
696      logging.warn("%s: Using a concatenated state is slower and will soon be "
697                   "deprecated.  Use state_is_tuple=True.", self)
698    if context.executing_eagerly() and context.num_gpus() > 0:
699      logging.warn("%s: Note that this cell is not optimized for performance. "
700                   "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
701                   "performance on GPU.", self)
702
703    # Inputs must be 2-dimensional.
704    self.input_spec = input_spec.InputSpec(ndim=2)
705
706    self._num_units = num_units
707    self._forget_bias = forget_bias
708    self._state_is_tuple = state_is_tuple
709    if activation:
710      self._activation = activations.get(activation)
711    else:
712      self._activation = math_ops.tanh
713
714  @property
715  def state_size(self):
716    return (LSTMStateTuple(self._num_units, self._num_units)
717            if self._state_is_tuple else 2 * self._num_units)
718
719  @property
720  def output_size(self):
721    return self._num_units
722
723  @tf_utils.shape_type_conversion
724  def build(self, inputs_shape):
725    if inputs_shape[-1] is None:
726      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
727                       % str(inputs_shape))
728    _check_supported_dtypes(self.dtype)
729    input_depth = inputs_shape[-1]
730    h_depth = self._num_units
731    self._kernel = self.add_variable(
732        _WEIGHTS_VARIABLE_NAME,
733        shape=[input_depth + h_depth, 4 * self._num_units])
734    self._bias = self.add_variable(
735        _BIAS_VARIABLE_NAME,
736        shape=[4 * self._num_units],
737        initializer=init_ops.zeros_initializer(dtype=self.dtype))
738
739    self.built = True
740
741  def call(self, inputs, state):
742    """Long short-term memory cell (LSTM).
743
744    Args:
745      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
746      state: An `LSTMStateTuple` of state tensors, each shaped
747        `[batch_size, num_units]`, if `state_is_tuple` has been set to
748        `True`.  Otherwise, a `Tensor` shaped
749        `[batch_size, 2 * num_units]`.
750
751    Returns:
752      A pair containing the new hidden state, and the new state (either a
753        `LSTMStateTuple` or a concatenated state, depending on
754        `state_is_tuple`).
755    """
756    _check_rnn_cell_input_dtypes([inputs, state])
757
758    sigmoid = math_ops.sigmoid
759    one = constant_op.constant(1, dtype=dtypes.int32)
760    # Parameters of gates are concatenated into one multiply for efficiency.
761    if self._state_is_tuple:
762      c, h = state
763    else:
764      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
765
766    gate_inputs = math_ops.matmul(
767        array_ops.concat([inputs, h], 1), self._kernel)
768    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
769
770    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
771    i, j, f, o = array_ops.split(
772        value=gate_inputs, num_or_size_splits=4, axis=one)
773
774    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
775    # Note that using `add` and `multiply` instead of `+` and `*` gives a
776    # performance improvement. So using those at the cost of readability.
777    add = math_ops.add
778    multiply = math_ops.multiply
779    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
780                multiply(sigmoid(i), self._activation(j)))
781    new_h = multiply(self._activation(new_c), sigmoid(o))
782
783    if self._state_is_tuple:
784      new_state = LSTMStateTuple(new_c, new_h)
785    else:
786      new_state = array_ops.concat([new_c, new_h], 1)
787    return new_h, new_state
788
789  def get_config(self):
790    config = {
791        "num_units": self._num_units,
792        "forget_bias": self._forget_bias,
793        "state_is_tuple": self._state_is_tuple,
794        "activation": activations.serialize(self._activation),
795        "reuse": self._reuse,
796    }
797    base_config = super(BasicLSTMCell, self).get_config()
798    return dict(list(base_config.items()) + list(config.items()))
799
800
801@tf_export(v1=["nn.rnn_cell.LSTMCell"])
802class LSTMCell(LayerRNNCell):
803  """Long short-term memory unit (LSTM) recurrent network cell.
804
805  The default non-peephole implementation is based on:
806
807    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
808
809  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
810  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
811
812  The peephole implementation is based on:
813
814    https://research.google.com/pubs/archive/43905.pdf
815
816  Hasim Sak, Andrew Senior, and Francoise Beaufays.
817  "Long short-term memory recurrent neural network architectures for
818   large scale acoustic modeling." INTERSPEECH, 2014.
819
820  The class uses optional peep-hole connections, optional cell clipping, and
821  an optional projection layer.
822
823  Note that this cell is not optimized for performance. Please use
824  `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
825  `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
826  better performance on CPU.
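
  A minimal usage sketch (`tf` denotes the TensorFlow module; sizes are
  illustrative):

  ```
  cell = LSTMCell(num_units=256, use_peepholes=True, num_proj=128)
  state = cell.zero_state(batch_size=32, dtype=tf.float32)
  # inputs: a [32, input_depth] tensor for a single time step.
  output, next_state = cell(inputs, state)  # output: [32, 128]
  ```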
827  """
828
  @deprecated(None, "This class is equivalent to tf.keras.layers.LSTMCell,"
                    " and will be replaced by that in TensorFlow 2.0.")
831  def __init__(self, num_units,
832               use_peepholes=False, cell_clip=None,
833               initializer=None, num_proj=None, proj_clip=None,
834               num_unit_shards=None, num_proj_shards=None,
835               forget_bias=1.0, state_is_tuple=True,
836               activation=None, reuse=None, name=None, dtype=None, **kwargs):
837    """Initialize the parameters for an LSTM cell.
838
839    Args:
840      num_units: int, The number of units in the LSTM cell.
841      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value. If provided, the cell state is
        clipped elementwise to within `[-cell_clip, cell_clip]` prior to the
        cell output activation.
844      initializer: (optional) The initializer to use for the weight and
845        projection matrices.
846      num_proj: (optional) int, The output dimensionality for the projection
847        matrices.  If None, no projection is performed.
848      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
849        provided, then the projected values are clipped elementwise to within
850        `[-proj_clip, proj_clip]`.
851      num_unit_shards: Deprecated, will be removed by Jan. 2017.
852        Use a variable_scope partitioner instead.
853      num_proj_shards: Deprecated, will be removed by Jan. 2017.
854        Use a variable_scope partitioner instead.
855      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        training. It must be set manually to `0.0` when restoring from
        CudnnLSTM-trained checkpoints.
859      state_is_tuple: If True, accepted and returned states are 2-tuples of
860        the `c_state` and `m_state`.  If False, they are concatenated
861        along the column axis.  This latter behavior will soon be deprecated.
862      activation: Activation function of the inner states.  Default: `tanh`. It
        can also be a string name of a Keras activation function.
864      reuse: (optional) Python boolean describing whether to reuse variables
865        in an existing scope.  If not `True`, and the existing scope already has
866        the given variables, an error is raised.
867      name: String, the name of the layer. Layers with the same name will
868        share weights, but to avoid mistakes we require reuse=True in such
869        cases.
870      dtype: Default dtype of the layer (default of `None` means use the type
871        of the first input). Required when `build` is called before `call`.
      **kwargs: Dict, keyword named properties for common layer attributes, like
        `trainable`, etc., when constructing the cell from the configs returned
        by `get_config()`.
874
875      When restoring from CudnnLSTM-trained checkpoints, use
876      `CudnnCompatibleLSTMCell` instead.
877    """
878    super(LSTMCell, self).__init__(
879        _reuse=reuse, name=name, dtype=dtype, **kwargs)
880    _check_supported_dtypes(self.dtype)
881    if not state_is_tuple:
882      logging.warn("%s: Using a concatenated state is slower and will soon be "
883                   "deprecated.  Use state_is_tuple=True.", self)
884    if num_unit_shards is not None or num_proj_shards is not None:
885      logging.warn(
886          "%s: The num_unit_shards and proj_unit_shards parameters are "
887          "deprecated and will be removed in Jan 2017.  "
888          "Use a variable scope with a partitioner instead.", self)
889    if context.executing_eagerly() and context.num_gpus() > 0:
890      logging.warn("%s: Note that this cell is not optimized for performance. "
891                   "Please use tf.contrib.cudnn_rnn.CudnnLSTM for better "
892                   "performance on GPU.", self)
893
894    # Inputs must be 2-dimensional.
895    self.input_spec = input_spec.InputSpec(ndim=2)
896
897    self._num_units = num_units
898    self._use_peepholes = use_peepholes
899    self._cell_clip = cell_clip
900    self._initializer = initializers.get(initializer)
901    self._num_proj = num_proj
902    self._proj_clip = proj_clip
903    self._num_unit_shards = num_unit_shards
904    self._num_proj_shards = num_proj_shards
905    self._forget_bias = forget_bias
906    self._state_is_tuple = state_is_tuple
907    if activation:
908      self._activation = activations.get(activation)
909    else:
910      self._activation = math_ops.tanh
911
912    if num_proj:
913      self._state_size = (
914          LSTMStateTuple(num_units, num_proj)
915          if state_is_tuple else num_units + num_proj)
916      self._output_size = num_proj
917    else:
918      self._state_size = (
919          LSTMStateTuple(num_units, num_units)
920          if state_is_tuple else 2 * num_units)
921      self._output_size = num_units
922
923  @property
924  def state_size(self):
925    return self._state_size
926
927  @property
928  def output_size(self):
929    return self._output_size
930
931  @tf_utils.shape_type_conversion
932  def build(self, inputs_shape):
933    if inputs_shape[-1] is None:
934      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
935                       % str(inputs_shape))
936    _check_supported_dtypes(self.dtype)
937    input_depth = inputs_shape[-1]
938    h_depth = self._num_units if self._num_proj is None else self._num_proj
939    maybe_partitioner = (
940        partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
941        if self._num_unit_shards is not None
942        else None)
943    self._kernel = self.add_variable(
944        _WEIGHTS_VARIABLE_NAME,
945        shape=[input_depth + h_depth, 4 * self._num_units],
946        initializer=self._initializer,
947        partitioner=maybe_partitioner)
948    if self.dtype is None:
949      initializer = init_ops.zeros_initializer
950    else:
951      initializer = init_ops.zeros_initializer(dtype=self.dtype)
952    self._bias = self.add_variable(
953        _BIAS_VARIABLE_NAME,
954        shape=[4 * self._num_units],
955        initializer=initializer)
956    if self._use_peepholes:
957      self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
958                                         initializer=self._initializer)
959      self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
960                                         initializer=self._initializer)
961      self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
962                                         initializer=self._initializer)
963
964    if self._num_proj is not None:
965      maybe_proj_partitioner = (
966          partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
967          if self._num_proj_shards is not None
968          else None)
969      self._proj_kernel = self.add_variable(
970          "projection/%s" % _WEIGHTS_VARIABLE_NAME,
971          shape=[self._num_units, self._num_proj],
972          initializer=self._initializer,
973          partitioner=maybe_proj_partitioner)
974
975    self.built = True
976
977  def call(self, inputs, state):
978    """Run one step of LSTM.
979
980    Args:
981      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
982      state: if `state_is_tuple` is False, this must be a state Tensor,
983        `2-D, [batch, state_size]`.  If `state_is_tuple` is True, this must be a
984        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
985        `m_state`.
986
987    Returns:
988      A tuple containing:
989
990      - A `2-D, [batch, output_dim]`, Tensor representing the output of the
991        LSTM after reading `inputs` when previous state was `state`.
992        Here output_dim is:
993           num_proj if num_proj was set,
994           num_units otherwise.
995      - Tensor(s) representing the new state of LSTM after reading `inputs` when
996        the previous state was `state`.  Same type and shape(s) as `state`.
997
998    Raises:
999      ValueError: If input size cannot be inferred from inputs via
1000        static shape inference.
1001    """
1002    _check_rnn_cell_input_dtypes([inputs, state])
1003
1004    num_proj = self._num_units if self._num_proj is None else self._num_proj
1005    sigmoid = math_ops.sigmoid
1006
1007    if self._state_is_tuple:
1008      (c_prev, m_prev) = state
1009    else:
1010      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
1011      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
1012
1013    input_size = inputs.get_shape().with_rank(2).dims[1].value
1014    if input_size is None:
1015      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1016
1017    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
1018    lstm_matrix = math_ops.matmul(
1019        array_ops.concat([inputs, m_prev], 1), self._kernel)
1020    lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)
1021
1022    i, j, f, o = array_ops.split(
1023        value=lstm_matrix, num_or_size_splits=4, axis=1)
1024    # Diagonal connections
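    # With peepholes enabled, the input and forget gates additionally see the
    # previous cell state c_prev (and, below, the output gate sees the new
    # cell state c) through the learned diagonal weights w_*_diag.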
1025    if self._use_peepholes:
1026      c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
1027           sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
1028    else:
1029      c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
1030           self._activation(j))
1031
1032    if self._cell_clip is not None:
1033      # pylint: disable=invalid-unary-operand-type
1034      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
1035      # pylint: enable=invalid-unary-operand-type
1036    if self._use_peepholes:
1037      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
1038    else:
1039      m = sigmoid(o) * self._activation(c)
1040
1041    if self._num_proj is not None:
1042      m = math_ops.matmul(m, self._proj_kernel)
1043
1044      if self._proj_clip is not None:
1045        # pylint: disable=invalid-unary-operand-type
1046        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
1047        # pylint: enable=invalid-unary-operand-type
1048
1049    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
1050                 array_ops.concat([c, m], 1))
1051    return m, new_state
1052
1053  def get_config(self):
1054    config = {
1055        "num_units": self._num_units,
1056        "use_peepholes": self._use_peepholes,
1057        "cell_clip": self._cell_clip,
1058        "initializer": initializers.serialize(self._initializer),
1059        "num_proj": self._num_proj,
1060        "proj_clip": self._proj_clip,
1061        "num_unit_shards": self._num_unit_shards,
1062        "num_proj_shards": self._num_proj_shards,
1063        "forget_bias": self._forget_bias,
1064        "state_is_tuple": self._state_is_tuple,
1065        "activation": activations.serialize(self._activation),
1066        "reuse": self._reuse,
1067    }
1068    base_config = super(LSTMCell, self).get_config()
1069    return dict(list(base_config.items()) + list(config.items()))
1070
1071
1072def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
1073  ix = [0]
1074  def enumerated_fn(*inner_args, **inner_kwargs):
1075    r = map_fn(ix[0], *inner_args, **inner_kwargs)
1076    ix[0] += 1
1077    return r
1078  return nest.map_structure_up_to(shallow_structure,
1079                                  enumerated_fn, *args, **kwargs)
1080
1081
1082def _default_dropout_state_filter_visitor(substate):
1083  if isinstance(substate, LSTMStateTuple):
1084    # Do not perform dropout on the memory state.
1085    return LSTMStateTuple(c=False, h=True)
1086  elif isinstance(substate, tensor_array_ops.TensorArray):
1087    return False
1088  return True
1089
1090
1091class _RNNCellWrapperV1(RNNCell):
1092  """Base class for cells wrappers V1 compatibility.
1093
1094  This class along with `_RNNCellWrapperV2` allows to define cells wrappers that
1095  are compatible with V1 and V2, and defines helper methods for this purpose.
1096  """
1097
1098  def __init__(self, cell):
1099    super(_RNNCellWrapperV1, self).__init__()
1100    self.cell = cell
1101    if isinstance(cell, trackable.Trackable):
1102      self._track_trackable(self.cell, name="cell")
1103
1104  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
1105    """Calls the wrapped cell and performs the wrapping logic.
1106
1107    This method is called from the wrapper's `call` or `__call__` methods.
1108
1109    Args:
1110      inputs: A tensor with wrapped cell's input.
1111      state: A tensor or tuple of tensors with wrapped cell's state.
1112      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
1114      **kwargs: Additional arguments.
1115
1116    Returns:
1117      A pair containing:
1118      - Output: A tensor with cell's output.
1119      - New state: A tensor or tuple of tensors with new wrapped cell's state.
1120    """
1121    raise NotImplementedError
1122
1123  def __call__(self, inputs, state, scope=None):
1124    """Runs the RNN cell step computation.
1125
1126    We assume that the wrapped RNNCell is being built within its `__call__`
1127    method. We directly use the wrapped cell's `__call__` in the overridden
1128    wrapper `__call__` method.
1129
    This allows the wrapped and non-wrapped cells to be used equivalently
    when using `__call__`.
1132
1133    Args:
1134      inputs: A tensor with wrapped cell's input.
1135      state: A tensor or tuple of tensors with wrapped cell's state.
      scope: VariableScope for the subgraph created in the wrapped cell's
1137        `__call__`.
1138
1139    Returns:
1140      A pair containing:
1141
1142      - Output: A tensor with cell's output.
1143      - New state: A tensor or tuple of tensors with new wrapped cell's state.
1144    """
1145    return self._call_wrapped_cell(
1146        inputs, state, cell_call_fn=self.cell.__call__, scope=scope)
1147
1148
1149class _RNNCellWrapperV2(keras_layer.AbstractRNNCell):
1150  """Base class for cells wrappers V2 compatibility.
1151
1152  This class along with `_RNNCellWrapperV1` allows to define cells wrappers that
1153  are compatible with V1 and V2, and defines helper methods for this purpose.
1154  """
1155
1156  def __init__(self, cell, *args, **kwargs):
1157    super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
1158    self.cell = cell
1159
1160  def call(self, inputs, state, **kwargs):
1161    """Runs the RNN cell step computation.
1162
    When `call` is used, we assume that the wrapper object has been built,
    and therefore that the wrapped cell has been built via its `build` method
    and its `call` method can be used directly.

    This allows the wrapped and non-wrapped cells to be used equivalently
    when using `call` and `build`.
1169
1170    Args:
1171      inputs: A tensor with wrapped cell's input.
1172      state: A tensor or tuple of tensors with wrapped cell's state.
1173      **kwargs: Additional arguments passed to the wrapped cell's `call`.
1174
1175    Returns:
1176      A pair containing:
1177
1178      - Output: A tensor with cell's output.
1179      - New state: A tensor or tuple of tensors with new wrapped cell's state.
1180    """
1181    return self._call_wrapped_cell(
1182        inputs, state, cell_call_fn=self.cell.call, **kwargs)
1183
1184  def build(self, inputs_shape):
1185    """Builds the wrapped cell."""
1186    self.cell.build(inputs_shape)
1187    self.built = True
1188
1189
1190class DropoutWrapperBase(object):
1191  """Operator adding dropout to inputs and outputs of the given cell."""
1192
1193  def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
1194               state_keep_prob=1.0, variational_recurrent=False,
1195               input_size=None, dtype=None, seed=None,
1196               dropout_state_filter_visitor=None):
1197    """Create a cell with added input, state, and/or output dropout.
1198
1199    If `variational_recurrent` is set to `True` (**NOT** the default behavior),
1200    then the same dropout mask is applied at every step, as described in:
1201
1202    Y. Gal, Z Ghahramani.  "A Theoretically Grounded Application of Dropout in
1203    Recurrent Neural Networks".  https://arxiv.org/abs/1512.05287
1204
1205    Otherwise a different dropout mask is applied at every time step.
1206
1207    Note, by default (unless a custom `dropout_state_filter` is provided),
1208    the memory state (`c` component of any `LSTMStateTuple`) passing through
1209    a `DropoutWrapper` is never modified.  This behavior is described in the
1210    above article.
1211
1212    Args:
      cell: an RNNCell; dropout is added to its inputs, states, and/or
        outputs as configured by the keep probabilities below.
1214      input_keep_prob: unit Tensor or float between 0 and 1, input keep
1215        probability; if it is constant and 1, no input dropout will be added.
1216      output_keep_prob: unit Tensor or float between 0 and 1, output keep
1217        probability; if it is constant and 1, no output dropout will be added.
      state_keep_prob: unit Tensor or float between 0 and 1, state keep
        probability; if it is constant and 1, no state dropout will be added.
        State dropout is performed on the outgoing states of the cell.
1221        **Note** the state components to which dropout is applied when
1222        `state_keep_prob` is in `(0, 1)` are also determined by
1223        the argument `dropout_state_filter_visitor` (e.g. by default dropout
1224        is never applied to the `c` component of an `LSTMStateTuple`).
1225      variational_recurrent: Python bool.  If `True`, then the same
1226        dropout pattern is applied across all time steps per run call.
1227        If this parameter is set, `input_size` **must** be provided.
1228      input_size: (optional) (possibly nested tuple of) `TensorShape` objects
1229        containing the depth(s) of the input tensors expected to be passed in to
1230        the `DropoutWrapper`.  Required and used **iff**
1231         `variational_recurrent = True` and `input_keep_prob < 1`.
1232      dtype: (optional) The `dtype` of the input, state, and output tensors.
1233        Required and used **iff** `variational_recurrent = True`.
1234      seed: (optional) integer, the randomness seed.
1235      dropout_state_filter_visitor: (optional), default: (see below).  Function
1236        that takes any hierarchical level of the state and returns
1237        a scalar or depth=1 structure of Python booleans describing
1238        which terms in the state should be dropped out.  In addition, if the
1239        function returns `True`, dropout is applied across this sublevel.  If
1240        the function returns `False`, dropout is not applied across this entire
1241        sublevel.
1242        Default behavior: perform dropout on all terms except the memory (`c`)
        state of `LSTMStateTuple` objects, and don't try to apply dropout to
1244        `TensorArray` objects:
1245        ```
1246        def dropout_state_filter_visitor(s):
          if isinstance(s, LSTMStateTuple):
            # Never perform dropout on the c state.
            return LSTMStateTuple(c=False, h=True)
1250          elif isinstance(s, TensorArray):
1251            return False
1252          return True
1253        ```
1254
1255    Raises:
1256      TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided
1257        but not `callable`.
1258      ValueError: if any of the keep_probs are not between 0 and 1.
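
    A minimal usage sketch (keep probabilities and sizes are illustrative;
    `tf` denotes the TensorFlow module):

    ```
    cell = LSTMCell(num_units=128)
    cell = DropoutWrapper(cell, input_keep_prob=0.9, output_keep_prob=0.9)
    state = cell.zero_state(batch_size=32, dtype=tf.float32)
    output, next_state = cell(inputs, state)
    ```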
1259    """
1260    super(DropoutWrapperBase, self).__init__(cell)
1261    assert_like_rnncell("cell", cell)
1262
1263    if (dropout_state_filter_visitor is not None
1264        and not callable(dropout_state_filter_visitor)):
1265      raise TypeError("dropout_state_filter_visitor must be callable")
1266    self._dropout_state_filter = (
1267        dropout_state_filter_visitor or _default_dropout_state_filter_visitor)
1268    with ops.name_scope("DropoutWrapperInit"):
1269      def tensor_and_const_value(v):
1270        tensor_value = ops.convert_to_tensor(v)
1271        const_value = tensor_util.constant_value(tensor_value)
1272        return (tensor_value, const_value)
1273      for prob, attr in [(input_keep_prob, "input_keep_prob"),
1274                         (state_keep_prob, "state_keep_prob"),
1275                         (output_keep_prob, "output_keep_prob")]:
1276        tensor_prob, const_prob = tensor_and_const_value(prob)
1277        if const_prob is not None:
1278          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %g"
                             % (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)

    # Set variational_recurrent, seed before running the code below
    self._variational_recurrent = variational_recurrent
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")

      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
        return array_ops.concat(
            ([1], tensor_shape.TensorShape(s).as_list()), 0)

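      # Note (added for clarity): convert_to_batch_shape prepends a leading 1,
      # so a per-step state of size 64 yields a noise tensor of shape [1, 64].
      # Because that leading dimension is 1, the same dropout mask broadcasts
      # across every batch element and is reused at every time step, which is
      # what variational (recurrent) dropout requires.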
      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure_up_to(
            input_size,
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure_up_to(
          cell.state_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure_up_to(
          cell.output_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)

  def _gen_seed(self, salt_prefix, index):
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

  @property
  def wrapped_cell(self):
    return self.cell

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self.cell.zero_state(batch_size, dtype)

  def _variational_recurrent_dropout_value(
      self, index, value, noise, keep_prob):
    """Performs dropout given the pre-calculated noise tensor."""
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob + noise

    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    ret = math_ops.div(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret
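
  # Worked example (added for clarity): with keep_prob = 0.75 and noise drawn
  # uniformly from [0, 1), random_tensor lies in [0.75, 1.75).  floor() maps
  # values in [0.75, 1.0) to 0 and values in [1.0, 1.75) to 1, so each unit
  # is kept with probability 0.75, and dividing the kept values by keep_prob
  # preserves the expected activation.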

  def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob,
               shallow_filtered_substructure=None):
    """Decides whether to perform standard dropout or recurrent dropout."""

    if shallow_filtered_substructure is None:
      # Put something so we traverse the entire structure; inside the
      # dropout function we check whether the leaves of this are bool or not.
      shallow_filtered_substructure = values

    if not self._variational_recurrent:
      def dropout(i, do_dropout, v):
        if not isinstance(do_dropout, bool) or do_dropout:
          return nn_ops.dropout(
              v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i))
        else:
          return v
      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values])
    else:
      def dropout(i, do_dropout, v, n):
        if not isinstance(do_dropout, bool) or do_dropout:
          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
        else:
          return v
      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values, recurrent_noise])

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Runs the wrapped cell and applies dropout.

    Args:
      inputs: A tensor with the wrapped cell's input.
      state: A tensor or tuple of tensors with the wrapped cell's state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments.

    Returns:
      A pair containing:

      - Output: A tensor with the cell's output.
      - New state: A tensor or tuple of tensors with the wrapped cell's new
        state.
    """
    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input",
                             self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = cell_call_fn(inputs, state, **kwargs)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state",
                                self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output",
                             self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state


@tf_export(v1=["nn.rnn_cell.DropoutWrapper"])
class DropoutWrapper(DropoutWrapperBase, _RNNCellWrapperV1):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, *args, **kwargs):
    super(DropoutWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = DropoutWrapperBase.__init__.__doc__
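
  # Illustrative usage sketch (added; not part of the original file): wrapping
  # an LSTMCell so that dropout with keep probability 0.9 is applied to its
  # inputs and outputs at every step.  `inputs` is assumed to be a
  # [batch, time, depth] float tensor.
  #
  #   cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
  #   cell = tf.nn.rnn_cell.DropoutWrapper(
  #       cell, input_keep_prob=0.9, output_keep_prob=0.9)
  #   outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)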


@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapperV2(DropoutWrapperBase, _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, *args, **kwargs):
    super(DropoutWrapperV2, self).__init__(*args, **kwargs)

  __init__.__doc__ = DropoutWrapperBase.__init__.__doc__
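
  # Illustrative sketch (added; assumed usage): the V2 wrapper is exported as
  # tf.nn.RNNCellDropoutWrapper and can wrap a Keras RNN cell before it is
  # handed to tf.keras.layers.RNN.
  #
  #   cell = tf.nn.RNNCellDropoutWrapper(
  #       tf.keras.layers.LSTMCell(64), output_keep_prob=0.9)
  #   layer = tf.keras.layers.RNN(cell)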


class ResidualWrapperBase(object):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell, residual_fn=None):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) The function to map raw cell inputs and raw cell
        outputs to the actual cell outputs of the residual network.
        Defaults to applying `lambda i, o: i + o` elementwise over the input
        and output structures via `nest.map_structure`.
    """
    super(ResidualWrapperBase, self).__init__(cell)
    self._residual_fn = residual_fn

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell and then apply the residual_fn to its inputs and outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      cell_call_fn: Wrapped cell's method to use for step computation (cell's
        `__call__` or `call` method).
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = cell_call_fn(inputs, state, **kwargs)
    # Ensure shapes match.
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())
    def default_residual_fn(inputs, outputs):
      nest.assert_same_structure(inputs, outputs)
      nest.map_structure(assert_shape_match, inputs, outputs)
      return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)
    res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
    return (res_outputs, new_state)


@tf_export(v1=["nn.rnn_cell.ResidualWrapper"])
class ResidualWrapper(ResidualWrapperBase, _RNNCellWrapperV1):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):
    super(ResidualWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = ResidualWrapperBase.__init__.__doc__
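
  # Illustrative usage sketch (added; not part of the original file): the
  # wrapped cell's output depth must match its input depth so the default
  # residual connection `input + output` is well defined.
  #
  #   cell = tf.nn.rnn_cell.GRUCell(num_units=256)   # inputs must also be 256
  #   cell = tf.nn.rnn_cell.ResidualWrapper(cell)
  #   outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)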


@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapperV2(ResidualWrapperBase, _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):
    super(ResidualWrapperV2, self).__init__(*args, **kwargs)

  __init__.__doc__ = ResidualWrapperBase.__init__.__doc__


class DeviceWrapperBase(object):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, cell, device):
    """Construct a `DeviceWrapper` for `cell` with device `device`.

    Ensures the wrapped `cell` is called with `tf.device(device)`.

    Args:
      cell: An instance of `RNNCell`.
      device: A device string or function, for passing to `tf.device`.
    """
    super(DeviceWrapperBase, self).__init__(cell)
    self._device = device

  @property
  def state_size(self):
    return self.cell.state_size

  @property
  def output_size(self):
    return self.cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      with ops.device(self._device):
        return self.cell.zero_state(batch_size, dtype)

  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
    """Run the cell on specified device."""
    with ops.device(self._device):
      return cell_call_fn(inputs, state, **kwargs)


@tf_export(v1=["nn.rnn_cell.DeviceWrapper"])
class DeviceWrapper(DeviceWrapperBase, _RNNCellWrapperV1):
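  """Operator that ensures an RNNCell runs on a particular device."""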

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DeviceWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = DeviceWrapperBase.__init__.__doc__
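
  # Illustrative usage sketch (added; not part of the original file): pinning
  # the wrapped cell's computation to a specific device.
  #
  #   cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
  #   cell = tf.nn.rnn_cell.DeviceWrapper(cell, "/device:GPU:0")
  #   outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)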


@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapperV2(DeviceWrapperBase, _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DeviceWrapperV2, self).__init__(*args, **kwargs)

  __init__.__doc__ = DeviceWrapperBase.__init__.__doc__


@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
class MultiRNNCell(RNNCell):
  """RNN cell composed sequentially of multiple simple cells.

  Example:

  ```python
  num_units = [128, 64]
  cells = [BasicLSTMCell(num_units=n) for n in num_units]
  stacked_rnn_cell = MultiRNNCell(cells)
  ```
  """

  @deprecated(None, "This class is equivalent to "
                    "tf.keras.layers.StackedRNNCells, and will be replaced by "
                    "that in TensorFlow 2.0.")
  def __init__(self, cells, state_is_tuple=True):
    """Create an RNN cell composed sequentially of a number of RNNCells.

    Args:
      cells: list of RNNCells that will be composed in this order.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`.  If False, the states are all
        concatenated along the column axis.  This latter behavior will soon be
        deprecated.

    Raises:
      ValueError: if cells is empty (not allowed), or at least one of the cells
        returns a state tuple but the flag `state_is_tuple` is `False`.
    """
    super(MultiRNNCell, self).__init__()
    if not cells:
      raise ValueError("Must specify at least one cell for MultiRNNCell.")
    if not nest.is_sequence(cells):
      raise TypeError(
          "cells must be a list or tuple, but saw: %s." % cells)

    if len(set([id(cell) for cell in cells])) < len(cells):
      logging.log_first_n(logging.WARN,
                          "At least two cells provided to MultiRNNCell "
                          "are the same object and will share weights.", 1)

    self._cells = cells
    for cell_number, cell in enumerate(self._cells):
      # Add Trackable dependencies on these cells so their variables get
      # saved with this object when using object-based saving.
      if isinstance(cell, trackable.Trackable):
        # TODO(allenl): Track down non-Trackable callers.
        self._track_trackable(cell, name="cell-%d" % (cell_number,))
    self._state_is_tuple = state_is_tuple
    if not state_is_tuple:
      if any(nest.is_sequence(c.state_size) for c in self._cells):
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set.  State sizes are: %s"
                         % str([c.state_size for c in self._cells]))

  @property
  def state_size(self):
    if self._state_is_tuple:
      return tuple(cell.state_size for cell in self._cells)
    else:
      return sum(cell.state_size for cell in self._cells)

  @property
  def output_size(self):
    return self._cells[-1].output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      if self._state_is_tuple:
        return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
      else:
        # We know here that state_size of each cell is not a tuple and
        # presumably does not contain TensorArrays or anything else fancy
        return super(MultiRNNCell, self).zero_state(batch_size, dtype)

  @property
  def trainable_weights(self):
    if not self.trainable:
      return []
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.trainable_weights
    return weights

  @property
  def non_trainable_weights(self):
    weights = []
    for cell in self._cells:
      if isinstance(cell, base_layer.Layer):
        weights += cell.non_trainable_weights
    if not self.trainable:
      trainable_weights = []
      for cell in self._cells:
        if isinstance(cell, base_layer.Layer):
          trainable_weights += cell.trainable_weights
      return trainable_weights + weights
    return weights

  def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state."""
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_sequence(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (tuple(new_states) if self._state_is_tuple else
                  array_ops.concat(new_states, 1))

    return cur_inp, new_states
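
  # Illustrative sketch (added; not part of the original file): stepping a
  # two-layer stack once.  `inputs` is assumed to be a [batch, depth] float
  # tensor with batch size 32.
  #
  #   cells = [BasicLSTMCell(128), BasicLSTMCell(64)]
  #   stacked = MultiRNNCell(cells)
  #   state = stacked.zero_state(batch_size=32, dtype=dtypes.float32)
  #   output, state = stacked(inputs, state)   # output shape: [32, 64]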


def _check_rnn_cell_input_dtypes(inputs):
  """Checks whether the input tensors have supported dtypes.

  Default RNN cells only support float and complex dtypes, since their
  activation functions (tanh and sigmoid) only accept those types. This
  function raises a clear error message if an input has an unsupported type.

  Args:
    inputs: a tensor or nested structure of tensors that are fed to the RNN
      cell as input or state.

  Raises:
    ValueError: if any of the input tensors do not have a float or complex
      dtype.
  """
  for t in nest.flatten(inputs):
    _check_supported_dtypes(t.dtype)


def _check_supported_dtypes(dtype):
  if dtype is None:
    return
  dtype = dtypes.as_dtype(dtype)
  if not (dtype.is_floating or dtype.is_complex):
    raise ValueError("RNN cell only supports floating point or complex "
                     "inputs, but saw dtype: %s" % dtype)

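# Illustrative note (added): _check_supported_dtypes accepts floating point
# and complex dtypes and rejects everything else, e.g.
#
#   _check_supported_dtypes(dtypes.float32)    # OK
#   _check_supported_dtypes(dtypes.complex64)  # OK
#   _check_supported_dtypes(dtypes.int32)      # raises ValueError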