1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for constructing RNN Cells."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import math
22
23from tensorflow.contrib.compiler import jit
24from tensorflow.contrib.layers.python.layers import layers
25from tensorflow.contrib.rnn.python.ops import core_rnn_cell
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import op_def_registry
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.layers import base as base_layer
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import clip_ops
33from tensorflow.python.ops import init_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import nn_impl  # pylint: disable=unused-import
36from tensorflow.python.ops import nn_ops
37from tensorflow.python.ops import partitioned_variables  # pylint: disable=unused-import
38from tensorflow.python.ops import random_ops
39from tensorflow.python.ops import rnn_cell_impl
40from tensorflow.python.ops import variable_scope as vs
41from tensorflow.python.platform import tf_logging as logging
42from tensorflow.python.util import nest
43
44
45def _get_concat_variable(name, shape, dtype, num_shards):
46  """Get a sharded variable concatenated into one tensor."""
47  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
48  if len(sharded_variable) == 1:
49    return sharded_variable[0]
50
51  concat_name = name + "/concat"
52  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
53  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
54    if value.name == concat_full_name:
55      return value
56
57  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
58  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
59  return concat_variable
60
61
62def _get_sharded_variable(name, shape, dtype, num_shards):
63  """Get a list of sharded variables with the given dtype."""
64  if num_shards > shape[0]:
65    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
66                                                                   num_shards))
67  unit_shard_size = int(math.floor(shape[0] / num_shards))
68  remaining_rows = shape[0] - unit_shard_size * num_shards
69
70  shards = []
71  for i in range(num_shards):
72    current_size = unit_shard_size
73    if i < remaining_rows:
74      current_size += 1
75    shards.append(
76        vs.get_variable(
77            name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
78  return shards
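# A worked example of the sharding arithmetic above (illustrative numbers
# only): with shape=[10, 4] and num_shards=3, unit_shard_size is 3 and
# remaining_rows is 1, so the shards get row counts [4, 3, 3] and are created
# as "<name>_0", "<name>_1" and "<name>_2".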
79
80
81def _norm(g, b, inp, scope):
82  shape = inp.get_shape()[-1:]
83  gamma_init = init_ops.constant_initializer(g)
84  beta_init = init_ops.constant_initializer(b)
85  with vs.variable_scope(scope):
86    # Initialize beta and gamma for use by layer_norm.
87    vs.get_variable("gamma", shape=shape, initializer=gamma_init)
88    vs.get_variable("beta", shape=shape, initializer=beta_init)
89  normalized = layers.layer_norm(inp, reuse=True, scope=scope)
90  return normalized
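# Usage sketch for _norm (hypothetical values, mirroring the calls made
# below): for a gate pre-activation `j` of shape [batch, num_units],
#   j = _norm(1.0, 0.0, j, "transform")
# creates "transform/gamma" and "transform/beta" variables initialized to the
# given gain/shift and returns the layer-normalized tensor of the same shape.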
91
92
93class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
94  """Long short-term memory unit (LSTM) recurrent network cell.
95
96  The default non-peephole implementation is based on:
97
98    http://www.bioinf.jku.at/publications/older/2604.pdf
99
100  S. Hochreiter and J. Schmidhuber.
101  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
102
103  The peephole implementation is based on:
104
105    https://research.google.com/pubs/archive/43905.pdf
106
107  Hasim Sak, Andrew Senior, and Francoise Beaufays.
108  "Long short-term memory recurrent neural network architectures for
109   large scale acoustic modeling." INTERSPEECH, 2014.
110
111  The coupling of input and forget gate is based on:
112
113    http://arxiv.org/pdf/1503.04069.pdf
114
115  Greff et al. "LSTM: A Search Space Odyssey"
116
117  The class uses optional peephole connections and an optional projection
118  layer.
119  Layer normalization implementation is based on:
120
121    https://arxiv.org/abs/1607.06450.
122
123  "Layer Normalization"
124  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
125
126  and is applied before the internal nonlinearities.
127
128  """
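  # Usage sketch (not part of the original file; assumes the TF 1.x graph-mode
  # API, the tf.contrib.rnn export of this class, and made-up sizes):
  #
  #   import tensorflow as tf
  #
  #   cell = tf.contrib.rnn.CoupledInputForgetGateLSTMCell(
  #       num_units=64, use_peepholes=True, num_proj=32, layer_norm=True)
  #   inputs = tf.placeholder(tf.float32, [8, 20, 128])  # [batch, time, depth]
  #   outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
  #   # outputs: [8, 20, 32] (num_proj); state: LSTMStateTuple(c=[8, 64],
  #   # h=[8, 32]).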
129
130  def __init__(self,
131               num_units,
132               use_peepholes=False,
133               initializer=None,
134               num_proj=None,
135               proj_clip=None,
136               num_unit_shards=1,
137               num_proj_shards=1,
138               forget_bias=1.0,
139               state_is_tuple=True,
140               activation=math_ops.tanh,
141               reuse=None,
142               layer_norm=False,
143               norm_gain=1.0,
144               norm_shift=0.0):
145    """Initialize the parameters for an LSTM cell.
146
147    Args:
148      num_units: int, The number of units in the LSTM cell
149      use_peepholes: bool, set True to enable diagonal/peephole connections.
150      initializer: (optional) The initializer to use for the weight and
151        projection matrices.
152      num_proj: (optional) int, The output dimensionality for the projection
153        matrices.  If None, no projection is performed.
154      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
155        provided, then the projected values are clipped elementwise to within
156        `[-proj_clip, proj_clip]`.
157      num_unit_shards: How to split the weight matrix.  If >1, the weight
158        matrix is stored across num_unit_shards.
159      num_proj_shards: How to split the projection matrix.  If >1, the
160        projection matrix is stored across num_proj_shards.
161      forget_bias: Biases of the forget gate are initialized by default to 1
162        in order to reduce the scale of forgetting at the beginning of
163        the training.
164      state_is_tuple: If True (the default), accepted and returned states are
165        2-tuples of the `c_state` and `m_state`.  If False, they are
166        concatenated along the column axis; this behavior will soon be deprecated.
167      activation: Activation function of the inner states.
168      reuse: (optional) Python boolean describing whether to reuse variables
169        in an existing scope.  If not `True`, and the existing scope already has
170        the given variables, an error is raised.
171      layer_norm: If `True`, layer normalization will be applied.
172      norm_gain: float, The layer normalization gain initial value. If
173        `layer_norm` has been set to `False`, this argument will be ignored.
174      norm_shift: float, The layer normalization shift initial value. If
175        `layer_norm` has been set to `False`, this argument will be ignored.
176    """
177    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
178    if not state_is_tuple:
179      logging.warn("%s: Using a concatenated state is slower and will soon be "
180                   "deprecated.  Use state_is_tuple=True.", self)
181    self._num_units = num_units
182    self._use_peepholes = use_peepholes
183    self._initializer = initializer
184    self._num_proj = num_proj
185    self._proj_clip = proj_clip
186    self._num_unit_shards = num_unit_shards
187    self._num_proj_shards = num_proj_shards
188    self._forget_bias = forget_bias
189    self._state_is_tuple = state_is_tuple
190    self._activation = activation
191    self._reuse = reuse
192    self._layer_norm = layer_norm
193    self._norm_gain = norm_gain
194    self._norm_shift = norm_shift
195
196    if num_proj:
197      self._state_size = (
198          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
199          if state_is_tuple else num_units + num_proj)
200      self._output_size = num_proj
201    else:
202      self._state_size = (
203          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
204          if state_is_tuple else 2 * num_units)
205      self._output_size = num_units
206
207  @property
208  def state_size(self):
209    return self._state_size
210
211  @property
212  def output_size(self):
213    return self._output_size
214
215  def call(self, inputs, state):
216    """Run one step of LSTM.
217
218    Args:
219      inputs: input Tensor, 2D, batch x input_size.
220      state: if `state_is_tuple` is False, this must be a state Tensor,
221        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
222        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
223        `m_state`.
224
225    Returns:
226      A tuple containing:
227      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
228        LSTM after reading `inputs` when previous state was `state`.
229        Here output_dim is:
230           num_proj if num_proj was set,
231           num_units otherwise.
232      - Tensor(s) representing the new state of LSTM after reading `inputs` when
233        the previous state was `state`.  Same type and shape(s) as `state`.
234
235    Raises:
236      ValueError: If input size cannot be inferred from inputs via
237        static shape inference.
238    """
239    sigmoid = math_ops.sigmoid
240
241    num_proj = self._num_units if self._num_proj is None else self._num_proj
242
243    if self._state_is_tuple:
244      (c_prev, m_prev) = state
245    else:
246      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
247      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
248
249    dtype = inputs.dtype
250    input_size = inputs.get_shape().with_rank(2)[1]
251    if input_size.value is None:
252      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
253    concat_w = _get_concat_variable(
254        "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
255        self._num_unit_shards)
256
257    b = vs.get_variable(
258        "B",
259        shape=[3 * self._num_units],
260        initializer=init_ops.zeros_initializer(),
261        dtype=dtype)
262
263    # j = new_input, f = forget_gate, o = output_gate
264    cell_inputs = array_ops.concat([inputs, m_prev], 1)
265    lstm_matrix = math_ops.matmul(cell_inputs, concat_w)
266
267    # If layer normalization is applied, do not add bias
268    if not self._layer_norm:
269      lstm_matrix = nn_ops.bias_add(lstm_matrix, b)
270
271    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
272
273    # Apply layer normalization
274    if self._layer_norm:
275      j = _norm(self._norm_gain, self._norm_shift, j, "transform")
276      f = _norm(self._norm_gain, self._norm_shift, f, "forget")
277      o = _norm(self._norm_gain, self._norm_shift, o, "output")
278
279    # Diagonal connections
280    if self._use_peepholes:
281      w_f_diag = vs.get_variable(
282          "W_F_diag", shape=[self._num_units], dtype=dtype)
283      w_o_diag = vs.get_variable(
284          "W_O_diag", shape=[self._num_units], dtype=dtype)
285
286    if self._use_peepholes:
287      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
288    else:
289      f_act = sigmoid(f + self._forget_bias)
290    c = (f_act * c_prev + (1 - f_act) * self._activation(j))
291
292    # Apply layer normalization
293    if self._layer_norm:
294      c = _norm(self._norm_gain, self._norm_shift, c, "state")
295
296    if self._use_peepholes:
297      m = sigmoid(o + w_o_diag * c) * self._activation(c)
298    else:
299      m = sigmoid(o) * self._activation(c)
300
301    if self._num_proj is not None:
302      concat_w_proj = _get_concat_variable("W_P",
303                                           [self._num_units, self._num_proj],
304                                           dtype, self._num_proj_shards)
305
306      m = math_ops.matmul(m, concat_w_proj)
307      if self._proj_clip is not None:
308        # pylint: disable=invalid-unary-operand-type
309        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
310        # pylint: enable=invalid-unary-operand-type
311
312    new_state = (
313        rnn_cell_impl.LSTMStateTuple(c, m)
314        if self._state_is_tuple else array_ops.concat([c, m], 1))
315    return m, new_state
316
317
318class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
319  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
320
321  This implementation is based on:
322
323    Tara N. Sainath and Bo Li
324    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
325    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
326
327  It uses optional peephole connections and optional cell clipping.
328  """
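  # Usage sketch (not in the original file; sizes are made up): the cell
  # slides a filter of `feature_size` input dimensions over the frequency
  # axis in steps of `frequency_skip`, running an LSTM over both time and
  # frequency.
  #
  #   cell = TimeFreqLSTMCell(num_units=16, use_peepholes=True,
  #                           feature_size=8, frequency_skip=4)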
329
330  def __init__(self,
331               num_units,
332               use_peepholes=False,
333               cell_clip=None,
334               initializer=None,
335               num_unit_shards=1,
336               forget_bias=1.0,
337               feature_size=None,
338               frequency_skip=1,
339               reuse=None):
340    """Initialize the parameters for an LSTM cell.
341
342    Args:
343      num_units: int, The number of units in the LSTM cell
344      use_peepholes: bool, set True to enable diagonal/peephole connections.
345      cell_clip: (optional) A float value, if provided the cell state is clipped
346        by this value prior to the cell output activation.
347      initializer: (optional) The initializer to use for the weight and
348        projection matrices.
349      num_unit_shards: int, How to split the weight matrix.  If >1, the weight
350        matrix is stored across num_unit_shards.
351      forget_bias: float, Biases of the forget gate are initialized by default
352        to 1 in order to reduce the scale of forgetting at the beginning
353        of the training.
354      feature_size: int, The size of the input feature the LSTM spans over.
355      frequency_skip: int, The amount the LSTM filter is shifted by in
356        frequency.
357      reuse: (optional) Python boolean describing whether to reuse variables
358        in an existing scope.  If not `True`, and the existing scope already has
359        the given variables, an error is raised.
360    """
361    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
362    self._num_units = num_units
363    self._use_peepholes = use_peepholes
364    self._cell_clip = cell_clip
365    self._initializer = initializer
366    self._num_unit_shards = num_unit_shards
367    self._forget_bias = forget_bias
368    self._feature_size = feature_size
369    self._frequency_skip = frequency_skip
370    self._state_size = 2 * num_units
371    self._output_size = num_units
372    self._reuse = reuse
373
374  @property
375  def output_size(self):
376    return self._output_size
377
378  @property
379  def state_size(self):
380    return self._state_size
381
382  def call(self, inputs, state):
383    """Run one step of LSTM.
384
385    Args:
386      inputs: input Tensor, 2D, batch x input_size.
387      state: state Tensor, 2D, batch x state_size.
388
389    Returns:
390      A tuple containing:
391      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
392        after reading "inputs" when previous state was "state".
393        Here output_dim is num_units.
394      - A 2D, batch x state_size, Tensor representing the new state of LSTM
395        after reading "inputs" when previous state was "state".
396    Raises:
397      ValueError: if an input_size was specified and the provided inputs have
398        a different dimension.
399    """
400    sigmoid = math_ops.sigmoid
401    tanh = math_ops.tanh
402
403    freq_inputs = self._make_tf_features(inputs)
404    dtype = inputs.dtype
405    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
406
407    concat_w = _get_concat_variable(
408        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
409        dtype, self._num_unit_shards)
410
411    b = vs.get_variable(
412        "B",
413        shape=[4 * self._num_units],
414        initializer=init_ops.zeros_initializer(),
415        dtype=dtype)
416
417    # Diagonal connections
418    if self._use_peepholes:
419      w_f_diag = vs.get_variable(
420          "W_F_diag", shape=[self._num_units], dtype=dtype)
421      w_i_diag = vs.get_variable(
422          "W_I_diag", shape=[self._num_units], dtype=dtype)
423      w_o_diag = vs.get_variable(
424          "W_O_diag", shape=[self._num_units], dtype=dtype)
425
426    # initialize the first freq state to be zero
427    m_prev_freq = array_ops.zeros(
428        [inputs.shape[0].value or inputs.get_shape()[0], self._num_units],
429        dtype)
430    for fq in range(len(freq_inputs)):
431      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
432                               [-1, self._num_units])
433      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
434                               [-1, self._num_units])
435      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
436      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
437      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
438      i, j, f, o = array_ops.split(
439          value=lstm_matrix, num_or_size_splits=4, axis=1)
440
441      if self._use_peepholes:
442        c = (
443            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
444            sigmoid(i + w_i_diag * c_prev) * tanh(j))
445      else:
446        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
447
448      if self._cell_clip is not None:
449        # pylint: disable=invalid-unary-operand-type
450        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
451        # pylint: enable=invalid-unary-operand-type
452
453      if self._use_peepholes:
454        m = sigmoid(o + w_o_diag * c) * tanh(c)
455      else:
456        m = sigmoid(o) * tanh(c)
457      m_prev_freq = m
458      if fq == 0:
459        state_out = array_ops.concat([c, m], 1)
460        m_out = m
461      else:
462        state_out = array_ops.concat([state_out, c, m], 1)
463        m_out = array_ops.concat([m_out, m], 1)
464    return m_out, state_out
465
466  def _make_tf_features(self, input_feat):
467    """Make the frequency features.
468
469    Args:
470      input_feat: input Tensor, 2D, batch x input_size.
471
472    Returns:
473      A list of frequency features, with each element containing:
474      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
475        for that frequency index. Here output_dim is feature_size.
476    Raises:
477      ValueError: if input_size cannot be inferred from static shape inference.
478    """
479    input_size = input_feat.get_shape().with_rank(2)[-1].value
480    if input_size is None:
481      raise ValueError("Cannot infer input_size from static shape inference.")
482    num_feats = int(
483        (input_size - self._feature_size) / (self._frequency_skip)) + 1
484    freq_inputs = []
485    for f in range(num_feats):
486      cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
487                                  [-1, self._feature_size])
488      freq_inputs.append(cur_input)
489    return freq_inputs
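  # Worked example of the slicing above (illustrative numbers): with
  # input_size=40, feature_size=8 and frequency_skip=4, num_feats is
  # (40 - 8) // 4 + 1 = 9, and the slices start at columns 0, 4, ..., 32,
  # each spanning 8 columns.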
490
491
492class GridLSTMCell(rnn_cell_impl.RNNCell):
493  """Grid Long short-term memory unit (LSTM) recurrent network cell.
494
495  The default is based on:
496    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
497    "Grid Long Short-Term Memory," Proc. ICLR 2016.
498    http://arxiv.org/abs/1507.01526
499
500  When peephole connections are used, the implementation is based on:
501    Tara N. Sainath and Bo Li
502    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
503    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
504
505  The code uses optional peephole connections, shared_weights and cell clipping.
506  """
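  # Usage sketch (not part of the original file; all sizes are made up and
  # assume an input feature of width 40):
  #
  #   cell = GridLSTMCell(num_units=16, use_peepholes=True,
  #                       feature_size=8, frequency_skip=4,
  #                       num_frequency_blocks=[9])
  #   # With state_is_tuple=True (the default) the state is a
  #   # GridLSTMStateTuple holding one (c, m) pair per frequency step, and
  #   # output_size is num_units * total_blocks * 2 = 16 * 9 * 2 = 288.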
507
508  def __init__(self,
509               num_units,
510               use_peepholes=False,
511               share_time_frequency_weights=False,
512               cell_clip=None,
513               initializer=None,
514               num_unit_shards=1,
515               forget_bias=1.0,
516               feature_size=None,
517               frequency_skip=None,
518               num_frequency_blocks=None,
519               start_freqindex_list=None,
520               end_freqindex_list=None,
521               couple_input_forget_gates=False,
522               state_is_tuple=True,
523               reuse=None):
524    """Initialize the parameters for an LSTM cell.
525
526    Args:
527      num_units: int, The number of units in the LSTM cell
528      use_peepholes: (optional) bool, default False. Set True to enable
529        diagonal/peephole connections.
530      share_time_frequency_weights: (optional) bool, default False. Set True to
531        enable shared cell weights between time and frequency LSTMs.
532      cell_clip: (optional) A float value, default None, if provided the cell
533        state is clipped by this value prior to the cell output activation.
534      initializer: (optional) The initializer to use for the weight and
535        projection matrices, default None.
536      num_unit_shards: (optional) int, default 1, How to split the weight
537        matrix. If > 1, the weight matrix is stored across num_unit_shards.
538      forget_bias: (optional) float, default 1.0, The initial bias of the
539        forget gates, used to reduce the scale of forgetting at the beginning
540        of the training.
541      feature_size: (optional) int, default None, The size of the input feature
542        the LSTM spans over.
543      frequency_skip: (optional) int, default None, The amount the LSTM filter
544        is shifted by in frequency.
545      num_frequency_blocks: [required] A list with the number of frequency
546        steps in each block, which together must cover the whole input feature
547        split defined by start_freqindex_list and end_freqindex_list.
548      start_freqindex_list: [optional], list of ints, default None,  The
549        starting frequency index for each frequency block.
550      end_freqindex_list: [optional], list of ints, default None. The ending
551        frequency index for each frequency block.
552      couple_input_forget_gates: (optional) bool, default False, Whether to
553        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
554        model parameters and computation cost.
555      state_is_tuple: If True (the default), accepted and returned states are
556        2-tuples of the `c_state` and `m_state`.  If False, they are
557        concatenated along the column axis; this behavior will soon be deprecated.
558      reuse: (optional) Python boolean describing whether to reuse variables
559        in an existing scope.  If not `True`, and the existing scope already has
560        the given variables, an error is raised.
561    Raises:
562      ValueError: if the num_frequency_blocks list is not specified
563    """
564    super(GridLSTMCell, self).__init__(_reuse=reuse)
565    if not state_is_tuple:
566      logging.warn("%s: Using a concatenated state is slower and will soon be "
567                   "deprecated.  Use state_is_tuple=True.", self)
568    self._num_units = num_units
569    self._use_peepholes = use_peepholes
570    self._share_time_frequency_weights = share_time_frequency_weights
571    self._couple_input_forget_gates = couple_input_forget_gates
572    self._state_is_tuple = state_is_tuple
573    self._cell_clip = cell_clip
574    self._initializer = initializer
575    self._num_unit_shards = num_unit_shards
576    self._forget_bias = forget_bias
577    self._feature_size = feature_size
578    self._frequency_skip = frequency_skip
579    self._start_freqindex_list = start_freqindex_list
580    self._end_freqindex_list = end_freqindex_list
581    self._num_frequency_blocks = num_frequency_blocks
582    self._total_blocks = 0
583    self._reuse = reuse
584    if self._num_frequency_blocks is None:
585      raise ValueError("Must specify num_frequency_blocks")
586
587    for block_index in range(len(self._num_frequency_blocks)):
588      self._total_blocks += int(self._num_frequency_blocks[block_index])
589    if state_is_tuple:
590      state_names = ""
591      for block_index in range(len(self._num_frequency_blocks)):
592        for freq_index in range(self._num_frequency_blocks[block_index]):
593          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
594          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
595      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
596                                                      state_names.strip(","))
597      self._state_size = self._state_tuple_type(*(
598          [num_units, num_units] * self._total_blocks))
599    else:
600      self._state_tuple_type = None
601      self._state_size = num_units * self._total_blocks * 2
602    self._output_size = num_units * self._total_blocks * 2
603
604  @property
605  def output_size(self):
606    return self._output_size
607
608  @property
609  def state_size(self):
610    return self._state_size
611
612  @property
613  def state_tuple_type(self):
614    return self._state_tuple_type
615
616  def call(self, inputs, state):
617    """Run one step of LSTM.
618
619    Args:
620      inputs: input Tensor, 2D, [batch, feature_size].
621      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
622        flag self._state_is_tuple.
623
624    Returns:
625      A tuple containing:
626      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
627        after reading "inputs" when previous state was "state".
628        Here output_dim is num_units.
629      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
630        after reading "inputs" when previous state was "state".
631    Raises:
632      ValueError: if an input_size was specified and the provided inputs have
633        a different dimension.
634    """
635    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
636    freq_inputs = self._make_tf_features(inputs)
637    m_out_lst = []
638    state_out_lst = []
639    for block in range(len(freq_inputs)):
640      m_out_lst_current, state_out_lst_current = self._compute(
641          freq_inputs[block],
642          block,
643          state,
644          batch_size,
645          state_is_tuple=self._state_is_tuple)
646      m_out_lst.extend(m_out_lst_current)
647      state_out_lst.extend(state_out_lst_current)
648    if self._state_is_tuple:
649      state_out = self._state_tuple_type(*state_out_lst)
650    else:
651      state_out = array_ops.concat(state_out_lst, 1)
652    m_out = array_ops.concat(m_out_lst, 1)
653    return m_out, state_out
654
655  def _compute(self,
656               freq_inputs,
657               block,
658               state,
659               batch_size,
660               state_prefix="state",
661               state_is_tuple=True):
662    """Run the actual computation of one step LSTM.
663
664    Args:
665      freq_inputs: list of Tensors, 2D, [batch, feature_size].
666      block: int, current frequency block index to process.
667      state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
668        the flag state_is_tuple.
669      batch_size: int32, batch size.
670      state_prefix: (optional) string, name prefix for states, defaults to
671        "state".
672      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.
673
674    Returns:
675      A tuple, containing:
676      - A list of [batch, output_dim] Tensors, representing the output of the
677        LSTM given the inputs and state.
678      - A list of [batch, state_size] Tensors, representing the LSTM state
679        values given the inputs and previous state.
680    """
681    sigmoid = math_ops.sigmoid
682    tanh = math_ops.tanh
683    num_gates = 3 if self._couple_input_forget_gates else 4
684    dtype = freq_inputs[0].dtype
685    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
686
687    concat_w_f = _get_concat_variable(
688        "W_f_%d" % block,
689        [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
690        dtype, self._num_unit_shards)
691    b_f = vs.get_variable(
692        "B_f_%d" % block,
693        shape=[num_gates * self._num_units],
694        initializer=init_ops.zeros_initializer(),
695        dtype=dtype)
696    if not self._share_time_frequency_weights:
697      concat_w_t = _get_concat_variable("W_t_%d" % block, [
698          actual_input_size + 2 * self._num_units, num_gates * self._num_units
699      ], dtype, self._num_unit_shards)
700      b_t = vs.get_variable(
701          "B_t_%d" % block,
702          shape=[num_gates * self._num_units],
703          initializer=init_ops.zeros_initializer(),
704          dtype=dtype)
705
706    if self._use_peepholes:
707      # Diagonal connections
708      if not self._couple_input_forget_gates:
709        w_f_diag_freqf = vs.get_variable(
710            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
711        w_f_diag_freqt = vs.get_variable(
712            "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
713      w_i_diag_freqf = vs.get_variable(
714          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
715      w_i_diag_freqt = vs.get_variable(
716          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
717      w_o_diag_freqf = vs.get_variable(
718          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
719      w_o_diag_freqt = vs.get_variable(
720          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
721      if not self._share_time_frequency_weights:
722        if not self._couple_input_forget_gates:
723          w_f_diag_timef = vs.get_variable(
724              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
725          w_f_diag_timet = vs.get_variable(
726              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
727        w_i_diag_timef = vs.get_variable(
728            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
729        w_i_diag_timet = vs.get_variable(
730            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
731        w_o_diag_timef = vs.get_variable(
732            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
733        w_o_diag_timet = vs.get_variable(
734            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
735
736    # initialize the first freq state to be zero
737    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
738    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
739    for freq_index in range(len(freq_inputs)):
740      if state_is_tuple:
741        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
742        c_prev_time = getattr(state, name_prefix + "_c")
743        m_prev_time = getattr(state, name_prefix + "_m")
744      else:
745        c_prev_time = array_ops.slice(
746            state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
747        m_prev_time = array_ops.slice(
748            state, [0, (2 * freq_index + 1) * self._num_units],
749            [-1, self._num_units])
750
751      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
752      cell_inputs = array_ops.concat(
753          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)
754
755      # F-LSTM
756      lstm_matrix_freq = nn_ops.bias_add(
757          math_ops.matmul(cell_inputs, concat_w_f), b_f)
758      if self._couple_input_forget_gates:
759        i_freq, j_freq, o_freq = array_ops.split(
760            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
761        f_freq = None
762      else:
763        i_freq, j_freq, f_freq, o_freq = array_ops.split(
764            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
765      # T-LSTM
766      if self._share_time_frequency_weights:
767        i_time = i_freq
768        j_time = j_freq
769        f_time = f_freq
770        o_time = o_freq
771      else:
772        lstm_matrix_time = nn_ops.bias_add(
773            math_ops.matmul(cell_inputs, concat_w_t), b_t)
774        if self._couple_input_forget_gates:
775          i_time, j_time, o_time = array_ops.split(
776              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
777          f_time = None
778        else:
779          i_time, j_time, f_time, o_time = array_ops.split(
780              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
781
782      # F-LSTM c_freq
783      # input gate activations
784      if self._use_peepholes:
785        i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
786                           w_i_diag_freqt * c_prev_time)
787      else:
788        i_freq_g = sigmoid(i_freq)
789      # forget gate activations
790      if self._couple_input_forget_gates:
791        f_freq_g = 1.0 - i_freq_g
792      else:
793        if self._use_peepholes:
794          f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
795                             c_prev_freq + w_f_diag_freqt * c_prev_time)
796        else:
797          f_freq_g = sigmoid(f_freq + self._forget_bias)
798      # cell state
799      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
800      if self._cell_clip is not None:
801        # pylint: disable=invalid-unary-operand-type
802        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
803                                        self._cell_clip)
804        # pylint: enable=invalid-unary-operand-type
805
806      # T-LSTM c_freq
807      # input gate activations
808      if self._use_peepholes:
809        if self._share_time_frequency_weights:
810          i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
811                             w_i_diag_freqt * c_prev_time)
812        else:
813          i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
814                             w_i_diag_timet * c_prev_time)
815      else:
816        i_time_g = sigmoid(i_time)
817      # forget gate activations
818      if self._couple_input_forget_gates:
819        f_time_g = 1.0 - i_time_g
820      else:
821        if self._use_peepholes:
822          if self._share_time_frequency_weights:
823            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
824                               c_prev_freq + w_f_diag_freqt * c_prev_time)
825          else:
826            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
827                               c_prev_freq + w_f_diag_timet * c_prev_time)
828        else:
829          f_time_g = sigmoid(f_time + self._forget_bias)
830      # cell state
831      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
832      if self._cell_clip is not None:
833        # pylint: disable=invalid-unary-operand-type
834        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
835                                        self._cell_clip)
836        # pylint: enable=invalid-unary-operand-type
837
838      # F-LSTM m_freq
839      if self._use_peepholes:
840        m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
841                         w_o_diag_freqt * c_time) * tanh(c_freq)
842      else:
843        m_freq = sigmoid(o_freq) * tanh(c_freq)
844
845      # T-LSTM m_time
846      if self._use_peepholes:
847        if self._share_time_frequency_weights:
848          m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
849                           w_o_diag_freqt * c_time) * tanh(c_time)
850        else:
851          m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
852                           w_o_diag_timet * c_time) * tanh(c_time)
853      else:
854        m_time = sigmoid(o_time) * tanh(c_time)
855
856      m_prev_freq = m_freq
857      c_prev_freq = c_freq
858      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
859      if freq_index == 0:
860        state_out_lst = [c_time, m_time]
861        m_out_lst = [m_time, m_freq]
862      else:
863        state_out_lst.extend([c_time, m_time])
864        m_out_lst.extend([m_time, m_freq])
865
866    return m_out_lst, state_out_lst
867
868  def _make_tf_features(self, input_feat, slice_offset=0):
869    """Make the frequency features.
870
871    Args:
872      input_feat: input Tensor, 2D, [batch, input_size].
873      slice_offset: (optional) Python int, default 0, the slicing offset is only
874        used for the backward processing in the BidirectionalGridLSTMCell. It
875        specifies a different starting point instead of always 0 to enable the
876        forward and backward processing look at different frequency blocks.
877
878    Returns:
879      A list of frequency features, with each element containing:
880      - A 2D, [batch, output_dim], Tensor representing the time-frequency
881        feature for that frequency index. Here output_dim is feature_size.
882    Raises:
883      ValueError: if input_size cannot be inferred from static shape inference.
884    """
885    input_size = input_feat.get_shape().with_rank(2)[-1].value
886    if input_size is None:
887      raise ValueError("Cannot infer input_size from static shape inference.")
888    if slice_offset > 0:
889      # Padding to the end
890      inputs = array_ops.pad(input_feat,
891                             array_ops.constant(
892                                 [0, 0, 0, slice_offset],
893                                 shape=[2, 2],
894                                 dtype=dtypes.int32), "CONSTANT")
895    elif slice_offset < 0:
896      # Padding to the front
897      inputs = array_ops.pad(input_feat,
898                             array_ops.constant(
899                                 [0, 0, -slice_offset, 0],
900                                 shape=[2, 2],
901                                 dtype=dtypes.int32), "CONSTANT")
902      slice_offset = 0
903    else:
904      inputs = input_feat
905    freq_inputs = []
906    if not self._start_freqindex_list:
907      if len(self._num_frequency_blocks) != 1:
908        raise ValueError("Length of num_frequency_blocks"
909                         " is not 1, but instead is %d" %
910                         len(self._num_frequency_blocks))
911      num_feats = int(
912          (input_size - self._feature_size) / (self._frequency_skip)) + 1
913      if num_feats != self._num_frequency_blocks[0]:
914        raise ValueError(
915            "Invalid num_frequency_blocks, requires %d but gets %d, please"
916            " check the input size and filter config are correct." %
917            (self._num_frequency_blocks[0], num_feats))
918      block_inputs = []
919      for f in range(num_feats):
920        cur_input = array_ops.slice(
921            inputs, [0, slice_offset + f * self._frequency_skip],
922            [-1, self._feature_size])
923        block_inputs.append(cur_input)
924      freq_inputs.append(block_inputs)
925    else:
926      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
927        raise ValueError("Length of start and end freqindex_list"
928                         " does not match %d %d" %
929                         (len(self._start_freqindex_list),
930                          len(self._end_freqindex_list)))
931      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
932        raise ValueError("Length of num_frequency_blocks"
933                         " is not equal to start_freqindex_list %d %d" %
934                         (len(self._num_frequency_blocks),
935                          len(self._start_freqindex_list)))
936      for b in range(len(self._start_freqindex_list)):
937        start_index = self._start_freqindex_list[b]
938        end_index = self._end_freqindex_list[b]
939        cur_size = end_index - start_index
940        block_feats = int(
941            (cur_size - self._feature_size) / (self._frequency_skip)) + 1
942        if block_feats != self._num_frequency_blocks[b]:
943          raise ValueError(
944              "Invalid num_frequency_blocks, requires %d but gets %d, please"
945              " check the input size and filter config are correct." %
946              (self._num_frequency_blocks[b], block_feats))
947        block_inputs = []
948        for f in range(block_feats):
949          cur_input = array_ops.slice(
950              inputs,
951              [0, start_index + slice_offset + f * self._frequency_skip],
952              [-1, self._feature_size])
953          block_inputs.append(cur_input)
954        freq_inputs.append(block_inputs)
955    return freq_inputs
956
957
958class BidirectionalGridLSTMCell(GridLSTMCell):
959  """Bidirectional GridLSTM cell.
960
961  The bidirectional connection is used only in the frequency direction, so it
962  does not affect the time direction's real-time processing that is required
963  for online recognition systems.
964  The current implementation uses different weights for the two directions.
965  """
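  # Usage sketch (not in the original file; sizes are made up): constructed
  # like GridLSTMCell plus an optional backward_slice_offset.  The output
  # concatenates the forward and backward directions, so output_size is
  # 2 * num_units * total_blocks * 2.
  #
  #   cell = BidirectionalGridLSTMCell(num_units=16, feature_size=8,
  #                                    frequency_skip=4,
  #                                    num_frequency_blocks=[9],
  #                                    backward_slice_offset=1)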
966
967  def __init__(self,
968               num_units,
969               use_peepholes=False,
970               share_time_frequency_weights=False,
971               cell_clip=None,
972               initializer=None,
973               num_unit_shards=1,
974               forget_bias=1.0,
975               feature_size=None,
976               frequency_skip=None,
977               num_frequency_blocks=None,
978               start_freqindex_list=None,
979               end_freqindex_list=None,
980               couple_input_forget_gates=False,
981               backward_slice_offset=0,
982               reuse=None):
983    """Initialize the parameters for an LSTM cell.
984
985    Args:
986      num_units: int, The number of units in the LSTM cell
987      use_peepholes: (optional) bool, default False. Set True to enable
988        diagonal/peephole connections.
989      share_time_frequency_weights: (optional) bool, default False. Set True to
990        enable shared cell weights between time and frequency LSTMs.
991      cell_clip: (optional) A float value, default None, if provided the cell
992        state is clipped by this value prior to the cell output activation.
993      initializer: (optional) The initializer to use for the weight and
994        projection matrices, default None.
995      num_unit_shards: (optional) int, default 1, How to split the weight
996        matrix. If > 1, the weight matrix is stored across num_unit_shards.
997      forget_bias: (optional) float, default 1.0, The initial bias of the
998        forget gates, used to reduce the scale of forgetting at the beginning
999        of the training.
1000      feature_size: (optional) int, default None, The size of the input feature
1001        the LSTM spans over.
1002      frequency_skip: (optional) int, default None, The amount the LSTM filter
1003        is shifted by in frequency.
1004      num_frequency_blocks: [required] A list with the number of frequency
1005        steps in each block, which together must cover the whole input feature
1006        split defined by start_freqindex_list and end_freqindex_list.
1007      start_freqindex_list: [optional], list of ints, default None,  The
1008        starting frequency index for each frequency block.
1009      end_freqindex_list: [optional], list of ints, default None. The ending
1010        frequency index for each frequency block.
1011      couple_input_forget_gates: (optional) bool, default False, Whether to
1012        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
1013        model parameters and computation cost.
1014      backward_slice_offset: (optional) int32, default 0, the starting offset to
1015        slice the feature for backward processing.
1016      reuse: (optional) Python boolean describing whether to reuse variables
1017        in an existing scope.  If not `True`, and the existing scope already has
1018        the given variables, an error is raised.
1019    """
1020    super(BidirectionalGridLSTMCell, self).__init__(
1021        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
1022        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
1023        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
1024        couple_input_forget_gates, True, reuse)
1025    self._backward_slice_offset = int(backward_slice_offset)
1026    state_names = ""
1027    for direction in ["fwd", "bwd"]:
1028      for block_index in range(len(self._num_frequency_blocks)):
1029        for freq_index in range(self._num_frequency_blocks[block_index]):
1030          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
1031                                                  block_index)
1032          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
1033    self._state_tuple_type = collections.namedtuple(
1034        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
1035    self._state_size = self._state_tuple_type(*(
1036        [num_units, num_units] * self._total_blocks * 2))
1037    self._output_size = 2 * num_units * self._total_blocks * 2
1038
1039  def call(self, inputs, state):
1040    """Run one step of LSTM.
1041
1042    Args:
1043      inputs: input Tensor, 2D, [batch, feature_size].
1044      state: tuple of Tensors, 2D, [batch, state_size].
1045
1046    Returns:
1047      A tuple containing:
1048      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
1049        after reading "inputs" when previous state was "state".
1050        Here output_dim is num_units.
1051      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
1052        after reading "inputs" when previous state was "state".
1053    Raises:
1054      ValueError: if an input_size was specified and the provided inputs have
1055        a different dimension.
1056    """
1057    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
1058    fwd_inputs = self._make_tf_features(inputs)
1059    if self._backward_slice_offset:
1060      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
1061    else:
1062      bwd_inputs = fwd_inputs
1063
1064    # Forward processing
1065    with vs.variable_scope("fwd"):
1066      fwd_m_out_lst = []
1067      fwd_state_out_lst = []
1068      for block in range(len(fwd_inputs)):
1069        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
1070            fwd_inputs[block],
1071            block,
1072            state,
1073            batch_size,
1074            state_prefix="fwd_state",
1075            state_is_tuple=True)
1076        fwd_m_out_lst.extend(fwd_m_out_lst_current)
1077        fwd_state_out_lst.extend(fwd_state_out_lst_current)
1078    # Backward processing
1079    bwd_m_out_lst = []
1080    bwd_state_out_lst = []
1081    with vs.variable_scope("bwd"):
1082      for block in range(len(bwd_inputs)):
1083        # Reverse the blocks
1084        bwd_inputs_reverse = bwd_inputs[block][::-1]
1085        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
1086            bwd_inputs_reverse,
1087            block,
1088            state,
1089            batch_size,
1090            state_prefix="bwd_state",
1091            state_is_tuple=True)
1092        bwd_m_out_lst.extend(bwd_m_out_lst_current)
1093        bwd_state_out_lst.extend(bwd_state_out_lst_current)
1094    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
1095    # Outputs are always concatenated as they are never used separately.
1096    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
1097    return m_out, state_out
1098
1099
1100# pylint: disable=protected-access
1101_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name
1102
1103# pylint: enable=protected-access
1104
1105
1106class AttentionCellWrapper(rnn_cell_impl.RNNCell):
1107  """Basic attention cell wrapper.
1108
1109  Implementation based on https://arxiv.org/abs/1409.0473.
1110  """
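  # Usage sketch (not part of the original file; assumes the TF 1.x contrib
  # API and made-up sizes):
  #
  #   import tensorflow as tf
  #
  #   base_cell = tf.contrib.rnn.BasicLSTMCell(64)
  #   attn_cell = tf.contrib.rnn.AttentionCellWrapper(
  #       base_cell, attn_length=10, state_is_tuple=True)
  #   inputs = tf.placeholder(tf.float32, [8, 20, 32])  # [batch, time, depth]
  #   outputs, state = tf.nn.dynamic_rnn(attn_cell, inputs, dtype=tf.float32)
  #   # outputs: [8, 20, 64]; state is a (cell state, attns, attn_states)
  #   # tuple.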
1111
1112  def __init__(self,
1113               cell,
1114               attn_length,
1115               attn_size=None,
1116               attn_vec_size=None,
1117               input_size=None,
1118               state_is_tuple=True,
1119               reuse=None):
1120    """Create a cell with attention.
1121
1122    Args:
1123      cell: an RNNCell, an attention is added to it.
1124      attn_length: integer, the size of an attention window.
1125      attn_size: integer, the size of an attention vector. Equal to
1126          cell.output_size by default.
1127      attn_vec_size: integer, the number of convolutional features calculated
1128          on the attention state and the size of the hidden layer built from
1129          the base cell state. Equal to attn_size by default.
1130      input_size: integer, the size of a hidden linear layer,
1131          built from inputs and attention. Derived from the input tensor
1132          by default.
1133      state_is_tuple: If True (the default), accepted and returned states are
1134        3-tuples of the cell state, the attention, and the attention history.
1135        If False, the states are all concatenated along the column axis.
1136      reuse: (optional) Python boolean describing whether to reuse variables
1137        in an existing scope.  If not `True`, and the existing scope already has
1138        the given variables, an error is raised.
1139
1140    Raises:
1141      TypeError: if cell is not an RNNCell.
1142      ValueError: if cell returns a state tuple but the flag
1143          `state_is_tuple` is `False` or if attn_length is zero or less.
1144    """
1145    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
1146    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
1147      raise TypeError("The parameter cell is not RNNCell.")
1148    if nest.is_sequence(cell.state_size) and not state_is_tuple:
1149      raise ValueError(
1150          "Cell returns tuple of states, but the flag "
1151          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
1152    if attn_length <= 0:
1153      raise ValueError(
1154          "attn_length should be greater than zero, got %s" % str(attn_length))
1155    if not state_is_tuple:
1156      logging.warn("%s: Using a concatenated state is slower and will soon be "
1157                   "deprecated.  Use state_is_tuple=True.", self)
1158    if attn_size is None:
1159      attn_size = cell.output_size
1160    if attn_vec_size is None:
1161      attn_vec_size = attn_size
1162    self._state_is_tuple = state_is_tuple
1163    self._cell = cell
1164    self._attn_vec_size = attn_vec_size
1165    self._input_size = input_size
1166    self._attn_size = attn_size
1167    self._attn_length = attn_length
1168    self._reuse = reuse
1169    self._linear1 = None
1170    self._linear2 = None
1171    self._linear3 = None
1172
1173  @property
1174  def state_size(self):
1175    size = (self._cell.state_size, self._attn_size,
1176            self._attn_size * self._attn_length)
1177    if self._state_is_tuple:
1178      return size
1179    else:
1180      return sum(list(size))
1181
1182  @property
1183  def output_size(self):
1184    return self._attn_size
1185
1186  def call(self, inputs, state):
1187    """Long short-term memory cell with attention (LSTMA)."""
1188    if self._state_is_tuple:
1189      state, attns, attn_states = state
1190    else:
1191      states = state
1192      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
1193      attns = array_ops.slice(states, [0, self._cell.state_size],
1194                              [-1, self._attn_size])
1195      attn_states = array_ops.slice(
1196          states, [0, self._cell.state_size + self._attn_size],
1197          [-1, self._attn_size * self._attn_length])
1198    attn_states = array_ops.reshape(attn_states,
1199                                    [-1, self._attn_length, self._attn_size])
1200    input_size = self._input_size
1201    if input_size is None:
1202      input_size = inputs.get_shape().as_list()[1]
1203    if self._linear1 is None:
1204      self._linear1 = _Linear([inputs, attns], input_size, True)
1205    inputs = self._linear1([inputs, attns])
1206    cell_output, new_state = self._cell(inputs, state)
1207    if self._state_is_tuple:
1208      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
1209    else:
1210      new_state_cat = new_state
1211    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
1212    with vs.variable_scope("attn_output_projection"):
1213      if self._linear2 is None:
1214        self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True)
1215      output = self._linear2([cell_output, new_attns])
1216    new_attn_states = array_ops.concat(
1217        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
1218    new_attn_states = array_ops.reshape(
1219        new_attn_states, [-1, self._attn_length * self._attn_size])
1220    new_state = (new_state, new_attns, new_attn_states)
1221    if not self._state_is_tuple:
1222      new_state = array_ops.concat(list(new_state), 1)
1223    return output, new_state
1224
1225  def _attention(self, query, attn_states):
1226    conv2d = nn_ops.conv2d
1227    reduce_sum = math_ops.reduce_sum
1228    softmax = nn_ops.softmax
1229    tanh = math_ops.tanh
1230
1231    with vs.variable_scope("attention"):
1232      k = vs.get_variable("attn_w",
1233                          [1, 1, self._attn_size, self._attn_vec_size])
1234      v = vs.get_variable("attn_v", [self._attn_vec_size])
1235      hidden = array_ops.reshape(attn_states,
1236                                 [-1, self._attn_length, 1, self._attn_size])
1237      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
1238      if self._linear3 is None:
1239        self._linear3 = _Linear(query, self._attn_vec_size, True)
1240      y = self._linear3(query)
1241      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
1242      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
1243      a = softmax(s)
1244      d = reduce_sum(
1245          array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
1246      new_attns = array_ops.reshape(d, [-1, self._attn_size])
1247      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
1248      return new_attns, new_attn_states
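  # Shape summary for _attention (descriptive only): `query` is
  # [batch, total state size] and `attn_states` is
  # [batch, attn_length, attn_size]; the method returns `new_attns` of shape
  # [batch, attn_size] and `new_attn_states` of shape
  # [batch, attn_length - 1, attn_size] (the oldest attention state is
  # dropped; the caller appends the new output to restore attn_length).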
1249
1250
1251class HighwayWrapper(rnn_cell_impl.RNNCell):
1252  """RNNCell wrapper that adds highway connection on cell input and output.
1253
1254  Based on:
1255    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
1256    arXiv preprint arXiv:1505.00387, 2015.
1257    https://arxiv.org/abs/1505.00387
1258  """
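  # Usage sketch (not in the original file; sizes are made up): the highway
  # connection mixes the cell input and output elementwise, so the wrapped
  # cell's output_size must match the input depth.
  #
  #   base_cell = rnn_cell_impl.GRUCell(32)
  #   cell = HighwayWrapper(base_cell, carry_bias_init=1.0)
  #   # Inputs to `cell` must then be [batch, 32] so that the carry/transform
  #   # mix in _highway is well defined.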
1259
1260  def __init__(self,
1261               cell,
1262               couple_carry_transform_gates=True,
1263               carry_bias_init=1.0):
1264    """Constructs a `HighwayWrapper` for `cell`.
1265
1266    Args:
1267      cell: An instance of `RNNCell`.
1268      couple_carry_transform_gates: boolean, whether the carry and transform
1269        gates should be coupled.
1270      carry_bias_init: float, carry gates bias initialization.
1271    """
1272    self._cell = cell
1273    self._couple_carry_transform_gates = couple_carry_transform_gates
1274    self._carry_bias_init = carry_bias_init
1275
1276  @property
1277  def state_size(self):
1278    return self._cell.state_size
1279
1280  @property
1281  def output_size(self):
1282    return self._cell.output_size
1283
1284  def zero_state(self, batch_size, dtype):
1285    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
1286      return self._cell.zero_state(batch_size, dtype)
1287
1288  def _highway(self, inp, out):
1289    input_size = inp.get_shape().with_rank(2)[1].value
1290    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
1291    carry_bias = vs.get_variable(
1292        "carry_b", [input_size],
1293        initializer=init_ops.constant_initializer(self._carry_bias_init))
1294    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
1295    if self._couple_carry_transform_gates:
1296      transform = 1 - carry
1297    else:
1298      transform_weight = vs.get_variable("transform_w",
1299                                         [input_size, input_size])
1300      transform_bias = vs.get_variable(
1301          "transform_b", [input_size],
1302          initializer=init_ops.constant_initializer(-self._carry_bias_init))
1303      transform = math_ops.sigmoid(
1304          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
1305    return inp * carry + out * transform
1306
1307  def __call__(self, inputs, state, scope=None):
1308    """Run the cell and add its inputs to its outputs.
1309
1310    Args:
1311      inputs: cell inputs.
1312      state: cell state.
1313      scope: optional cell scope.
1314
1315    Returns:
1316      Tuple of cell outputs and new state.
1317
1318    Raises:
1319      TypeError: If cell inputs and outputs have different structure (type).
1320      ValueError: If cell inputs and outputs have different structure (value).
1321    """
1322    outputs, new_state = self._cell(inputs, state, scope=scope)
1323    nest.assert_same_structure(inputs, outputs)
1324
1325    # Ensure shapes match
1326    def assert_shape_match(inp, out):
1327      inp.get_shape().assert_is_compatible_with(out.get_shape())
1328
1329    nest.map_structure(assert_shape_match, inputs, outputs)
1330    res_outputs = nest.map_structure(self._highway, inputs, outputs)
1331    return (res_outputs, new_state)
1332
1333
1334class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
1335  """LSTM unit with layer normalization and recurrent dropout.
1336
1337  This class adds layer normalization and recurrent dropout to a
1338  basic LSTM unit. Layer normalization implementation is based on:
1339
1340    https://arxiv.org/abs/1607.06450.
1341
1342  "Layer Normalization"
1343  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
1344
1345  and is applied before the internal nonlinearities.
  Recurrent dropout is based on:
1347
1348    https://arxiv.org/abs/1603.05118
1349
1350  "Recurrent Dropout without Memory Loss"
1351  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
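
  Example (a minimal usage sketch; assumes the cell is exported as
  `tf.contrib.rnn.LayerNormBasicLSTMCell`; shapes are illustrative only):

    import tensorflow as tf

    cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
        64, layer_norm=True, dropout_keep_prob=0.9)
    inputs = tf.placeholder(tf.float32, [32, 20, 128])  # [batch, time, depth]
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)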
1352  """
1353
1354  def __init__(self,
1355               num_units,
1356               forget_bias=1.0,
1357               input_size=None,
1358               activation=math_ops.tanh,
1359               layer_norm=True,
1360               norm_gain=1.0,
1361               norm_shift=0.0,
1362               dropout_keep_prob=1.0,
1363               dropout_prob_seed=None,
1364               reuse=None):
1365    """Initializes the basic LSTM cell.
1366
1367    Args:
1368      num_units: int, The number of units in the LSTM cell.
1369      forget_bias: float, The bias added to forget gates (see above).
1370      input_size: Deprecated and unused.
1371      activation: Activation function of the inner states.
1372      layer_norm: If `True`, layer normalization will be applied.
1373      norm_gain: float, The layer normalization gain initial value. If
1374        `layer_norm` has been set to `False`, this argument will be ignored.
1375      norm_shift: float, The layer normalization shift initial value. If
1376        `layer_norm` has been set to `False`, this argument will be ignored.
1377      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
1378        recurrent dropout probability value. If float and 1.0, no dropout will
1379        be applied.
1380      dropout_prob_seed: (optional) integer, the randomness seed.
1381      reuse: (optional) Python boolean describing whether to reuse variables
1382        in an existing scope.  If not `True`, and the existing scope already has
1383        the given variables, an error is raised.
1384    """
1385    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)
1386
1387    if input_size is not None:
1388      logging.warn("%s: The input_size parameter is deprecated.", self)
1389
1390    self._num_units = num_units
1391    self._activation = activation
1392    self._forget_bias = forget_bias
1393    self._keep_prob = dropout_keep_prob
1394    self._seed = dropout_prob_seed
1395    self._layer_norm = layer_norm
1396    self._norm_gain = norm_gain
1397    self._norm_shift = norm_shift
1398    self._reuse = reuse
1399
1400  @property
1401  def state_size(self):
1402    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
1403
1404  @property
1405  def output_size(self):
1406    return self._num_units
1407
1408  def _norm(self, inp, scope, dtype=dtypes.float32):
1409    shape = inp.get_shape()[-1:]
1410    gamma_init = init_ops.constant_initializer(self._norm_gain)
1411    beta_init = init_ops.constant_initializer(self._norm_shift)
1412    with vs.variable_scope(scope):
1413      # Initialize beta and gamma for use by layer_norm.
1414      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
1415      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
1416    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
1417    return normalized
1418
1419  def _linear(self, args):
1420    out_size = 4 * self._num_units
1421    proj_size = args.get_shape()[-1]
1422    dtype = args.dtype
1423    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
1424    out = math_ops.matmul(args, weights)
1425    if not self._layer_norm:
1426      bias = vs.get_variable("bias", [out_size], dtype=dtype)
1427      out = nn_ops.bias_add(out, bias)
1428    return out
1429
1430  def call(self, inputs, state):
1431    """LSTM cell with layer normalization and recurrent dropout."""
1432    c, h = state
1433    args = array_ops.concat([inputs, h], 1)
1434    concat = self._linear(args)
1435    dtype = args.dtype
1436
1437    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
1438    if self._layer_norm:
1439      i = self._norm(i, "input", dtype=dtype)
1440      j = self._norm(j, "transform", dtype=dtype)
1441      f = self._norm(f, "forget", dtype=dtype)
1442      o = self._norm(o, "output", dtype=dtype)
1443
1444    g = self._activation(j)
1445    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
1446      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
1447
1448    new_c = (
1449        c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
1450    if self._layer_norm:
1451      new_c = self._norm(new_c, "state", dtype=dtype)
1452    new_h = self._activation(new_c) * math_ops.sigmoid(o)
1453
1454    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
1455    return new_h, new_state
1456
1457
1458class NASCell(rnn_cell_impl.RNNCell):
1459  """Neural Architecture Search (NAS) recurrent network cell.
1460
1461  This implements the recurrent cell from the paper:
1462
1463    https://arxiv.org/abs/1611.01578
1464
1465  Barret Zoph and Quoc V. Le.
1466  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.
1467
1468  The class uses an optional projection layer.
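
  Example (a minimal usage sketch; assumes the cell is exported as
  `tf.contrib.rnn.NASCell` and driven by `tf.nn.dynamic_rnn`; shapes are
  illustrative only):

    import tensorflow as tf

    cell = tf.contrib.rnn.NASCell(num_units=128, num_proj=64)
    inputs = tf.placeholder(tf.float32, [8, 20, 32])  # [batch, time, depth]
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)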
1469  """
1470
1471  def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
1472    """Initialize the parameters for a NAS cell.
1473
1474    Args:
1475      num_units: int, The number of units in the NAS cell
1476      num_proj: (optional) int, The output dimensionality for the projection
1477        matrices.  If None, no projection is performed.
1478      use_biases: (optional) bool, If True then use biases within the cell. This
1479        is False by default.
1480      reuse: (optional) Python boolean describing whether to reuse variables
1481        in an existing scope.  If not `True`, and the existing scope already has
1482        the given variables, an error is raised.
1483    """
1484    super(NASCell, self).__init__(_reuse=reuse)
1485    self._num_units = num_units
1486    self._num_proj = num_proj
1487    self._use_biases = use_biases
1488    self._reuse = reuse
1489
1490    if num_proj is not None:
1491      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
1492      self._output_size = num_proj
1493    else:
1494      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
1495      self._output_size = num_units
1496
1497  @property
1498  def state_size(self):
1499    return self._state_size
1500
1501  @property
1502  def output_size(self):
1503    return self._output_size
1504
1505  def call(self, inputs, state):
1506    """Run one step of NAS Cell.
1507
1508    Args:
1509      inputs: input Tensor, 2D, batch x num_units.
1510      state: This must be a tuple of state Tensors, both `2-D`, with column
1511        sizes `c_state` and `m_state`.
1512
1513    Returns:
1514      A tuple containing:
1515      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
1516        NAS Cell after reading `inputs` when previous state was `state`.
1517        Here output_dim is:
1518           num_proj if num_proj was set,
1519           num_units otherwise.
1520      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
1521        when the previous state was `state`.  Same type and shape(s) as `state`.
1522
1523    Raises:
1524      ValueError: If input size cannot be inferred from inputs via
1525        static shape inference.
1526    """
1527    sigmoid = math_ops.sigmoid
1528    tanh = math_ops.tanh
1529    relu = nn_ops.relu
1530
1531    num_proj = self._num_units if self._num_proj is None else self._num_proj
1532
1533    (c_prev, m_prev) = state
1534
1535    dtype = inputs.dtype
1536    input_size = inputs.get_shape().with_rank(2)[1]
1537    if input_size.value is None:
1538      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    # Variables for the NAS cell. W_m is all matrices multiplying the
    # hidden state and W_inputs is all matrices multiplying the inputs.
1541    concat_w_m = vs.get_variable("recurrent_kernel",
1542                                 [num_proj, 8 * self._num_units], dtype)
1543    concat_w_inputs = vs.get_variable(
1544        "kernel", [input_size.value, 8 * self._num_units], dtype)
1545
1546    m_matrix = math_ops.matmul(m_prev, concat_w_m)
1547    inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
1548
1549    if self._use_biases:
1550      b = vs.get_variable(
1551          "bias",
1552          shape=[8 * self._num_units],
1553          initializer=init_ops.zeros_initializer(),
1554          dtype=dtype)
1555      m_matrix = nn_ops.bias_add(m_matrix, b)
1556
    # The NAS cell branches into 8 different splits for both the hidden state
    # and the input.
1559    m_matrix_splits = array_ops.split(
1560        axis=1, num_or_size_splits=8, value=m_matrix)
1561    inputs_matrix_splits = array_ops.split(
1562        axis=1, num_or_size_splits=8, value=inputs_matrix)
1563
1564    # First layer
1565    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
1566    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
1567    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
1568    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
1569    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
1570    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
1571    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
1572    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
1573
1574    # Second layer
1575    l2_0 = tanh(layer1_0 * layer1_1)
1576    l2_1 = tanh(layer1_2 + layer1_3)
1577    l2_2 = tanh(layer1_4 * layer1_5)
1578    l2_3 = sigmoid(layer1_6 + layer1_7)
1579
1580    # Inject the cell
1581    l2_0 = tanh(l2_0 + c_prev)
1582
1583    # Third layer
1584    l3_0_pre = l2_0 * l2_1
1585    new_c = l3_0_pre  # create new cell
1586    l3_0 = l3_0_pre
1587    l3_1 = tanh(l2_2 + l2_3)
1588
1589    # Final layer
1590    new_m = tanh(l3_0 * l3_1)
1591
1592    # Projection layer if specified
1593    if self._num_proj is not None:
1594      concat_w_proj = vs.get_variable("projection_weights",
1595                                      [self._num_units, self._num_proj], dtype)
1596      new_m = math_ops.matmul(new_m, concat_w_proj)
1597
1598    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
1599    return new_m, new_state
1600
1601
1602class UGRNNCell(rnn_cell_impl.RNNCell):
1603  """Update Gate Recurrent Neural Network (UGRNN) cell.
1604
  A compromise between an LSTM/GRU and a vanilla RNN.  There is only one
  gate, which determines whether the unit should integrate its previous
  state or compute its output instantaneously.  This is the recurrent
  analogue of the gating idea in the feedforward Highway Network.
1609
1610  This implements the recurrent cell from the paper:
1611
1612    https://arxiv.org/abs/1611.09913
1613
1614  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
1615  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
1616  """
1617
1618  def __init__(self,
1619               num_units,
1620               initializer=None,
1621               forget_bias=1.0,
1622               activation=math_ops.tanh,
1623               reuse=None):
1624    """Initialize the parameters for an UGRNN cell.
1625
1626    Args:
1627      num_units: int, The number of units in the UGRNN cell
1628      initializer: (optional) The initializer to use for the weight matrices.
1629      forget_bias: (optional) float, default 1.0, The initial bias of the
1630        forget gate, used to reduce the scale of forgetting at the beginning
1631        of the training.
1632      activation: (optional) Activation function of the inner states.
1633        Default is `tf.tanh`.
1634      reuse: (optional) Python boolean describing whether to reuse variables
1635        in an existing scope.  If not `True`, and the existing scope already has
1636        the given variables, an error is raised.
1637    """
1638    super(UGRNNCell, self).__init__(_reuse=reuse)
1639    self._num_units = num_units
1640    self._initializer = initializer
1641    self._forget_bias = forget_bias
1642    self._activation = activation
1643    self._reuse = reuse
1644    self._linear = None
1645
1646  @property
1647  def state_size(self):
1648    return self._num_units
1649
1650  @property
1651  def output_size(self):
1652    return self._num_units
1653
1654  def call(self, inputs, state):
1655    """Run one step of UGRNN.
1656
1657    Args:
1658      inputs: input Tensor, 2D, batch x input size.
1659      state: state Tensor, 2D, batch x num units.
1660
1661    Returns:
1662      new_output: batch x num units, Tensor representing the output of the UGRNN
1663        after reading `inputs` when previous state was `state`. Identical to
1664        `new_state`.
1665      new_state: batch x num units, Tensor representing the state of the UGRNN
1666        after reading `inputs` when previous state was `state`.
1667
1668    Raises:
1669      ValueError: If input size cannot be inferred from inputs via
1670        static shape inference.
1671    """
1672    sigmoid = math_ops.sigmoid
1673
1674    input_size = inputs.get_shape().with_rank(2)[1]
1675    if input_size.value is None:
1676      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1677
1678    with vs.variable_scope(
1679        vs.get_variable_scope(), initializer=self._initializer):
1680      cell_inputs = array_ops.concat([inputs, state], 1)
1681      if self._linear is None:
1682        self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
1683      rnn_matrix = self._linear(cell_inputs)
1684
1685      [g_act, c_act] = array_ops.split(
1686          axis=1, num_or_size_splits=2, value=rnn_matrix)
1687
1688      c = self._activation(c_act)
1689      g = sigmoid(g_act + self._forget_bias)
1690      new_state = g * state + (1.0 - g) * c
1691      new_output = new_state
1692
1693    return new_output, new_state
1694
1695
1696class IntersectionRNNCell(rnn_cell_impl.RNNCell):
1697  """Intersection Recurrent Neural Network (+RNN) cell.
1698
1699  Architecture with coupled recurrent gate as well as coupled depth
1700  gate, designed to improve information flow through stacked RNNs. As the
1701  architecture uses depth gating, the dimensionality of the depth
1702  output (y) also should not change through depth (input size == output size).
1703  To achieve this, the first layer of a stacked Intersection RNN projects
1704  the inputs to N (num units) dimensions. Therefore when initializing an
1705  IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
1706  and use default settings for subsequent layers.
1707
1708  This implements the recurrent cell from the paper:
1709
1710    https://arxiv.org/abs/1611.09913
1711
1712  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
1713  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
1714
1715  The Intersection RNN is built for use in deeply stacked
1716  RNNs so it may not achieve best performance with depth 1.
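
  Example (a minimal usage sketch; assumes these classes are exported under
  `tf.contrib.rnn`; only the first layer projects the inputs to `num_units`):

    import tensorflow as tf

    num_units = 64
    first = tf.contrib.rnn.IntersectionRNNCell(num_units, num_in_proj=num_units)
    rest = [tf.contrib.rnn.IntersectionRNNCell(num_units) for _ in range(2)]
    cell = tf.contrib.rnn.MultiRNNCell([first] + rest)
    inputs = tf.placeholder(tf.float32, [16, 30, 48])  # depth != num_units
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)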
1717  """
1718
1719  def __init__(self,
1720               num_units,
1721               num_in_proj=None,
1722               initializer=None,
1723               forget_bias=1.0,
1724               y_activation=nn_ops.relu,
1725               reuse=None):
1726    """Initialize the parameters for an +RNN cell.
1727
1728    Args:
1729      num_units: int, The number of units in the +RNN cell
1730      num_in_proj: (optional) int, The input dimensionality for the RNN.
1731        If creating the first layer of an +RNN, this should be set to
1732        `num_units`. Otherwise, this should be set to `None` (default).
        If `None`, the dimensionality of `inputs` must equal `num_units`;
        otherwise a `ValueError` is raised.
1735      initializer: (optional) The initializer to use for the weight matrices.
1736      forget_bias: (optional) float, default 1.0, The initial bias of the
1737        forget gates, used to reduce the scale of forgetting at the beginning
1738        of the training.
      y_activation: (optional) Activation function of the states passed
        through depth. Default is `tf.nn.relu`.
1741      reuse: (optional) Python boolean describing whether to reuse variables
1742        in an existing scope.  If not `True`, and the existing scope already has
1743        the given variables, an error is raised.
1744    """
1745    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
1746    self._num_units = num_units
1747    self._initializer = initializer
1748    self._forget_bias = forget_bias
1749    self._num_input_proj = num_in_proj
1750    self._y_activation = y_activation
1751    self._reuse = reuse
1752    self._linear1 = None
1753    self._linear2 = None
1754
1755  @property
1756  def state_size(self):
1757    return self._num_units
1758
1759  @property
1760  def output_size(self):
1761    return self._num_units
1762
1763  def call(self, inputs, state):
1764    """Run one step of the Intersection RNN.
1765
1766    Args:
1767      inputs: input Tensor, 2D, batch x input size.
1768      state: state Tensor, 2D, batch x num units.
1769
1770    Returns:
1771      new_y: batch x num units, Tensor representing the output of the +RNN
1772        after reading `inputs` when previous state was `state`.
1773      new_state: batch x num units, Tensor representing the state of the +RNN
1774        after reading `inputs` when previous state was `state`.
1775
1776    Raises:
1777      ValueError: If input size cannot be inferred from `inputs` via
1778        static shape inference.
1779      ValueError: If input size != output size (these must be equal when
1780        using the Intersection RNN).
1781    """
1782    sigmoid = math_ops.sigmoid
1783    tanh = math_ops.tanh
1784
1785    input_size = inputs.get_shape().with_rank(2)[1]
1786    if input_size.value is None:
1787      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1788
1789    with vs.variable_scope(
1790        vs.get_variable_scope(), initializer=self._initializer):
1791      # read-in projections (should be used for first layer in deep +RNN
1792      # to transform size of inputs from I --> N)
1793      if input_size.value != self._num_units:
1794        if self._num_input_proj:
1795          with vs.variable_scope("in_projection"):
1796            if self._linear1 is None:
1797              self._linear1 = _Linear(inputs, self._num_units, True)
1798            inputs = self._linear1(inputs)
1799        else:
1800          raise ValueError("Must have input size == output size for "
1801                           "Intersection RNN. To fix, num_in_proj should "
1802                           "be set to num_units at cell init.")
1803
1804      n_dim = i_dim = self._num_units
1805      cell_inputs = array_ops.concat([inputs, state], 1)
1806      if self._linear2 is None:
1807        self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
1808      rnn_matrix = self._linear2(cell_inputs)
1809
1810      gh_act = rnn_matrix[:, :n_dim]  # b x n
1811      h_act = rnn_matrix[:, n_dim:2 * n_dim]  # b x n
1812      gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim]  # b x i
1813      y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim]  # b x i
1814
1815      h = tanh(h_act)
1816      y = self._y_activation(y_act)
1817      gh = sigmoid(gh_act + self._forget_bias)
1818      gy = sigmoid(gy_act + self._forget_bias)
1819
1820      new_state = gh * state + (1.0 - gh) * h  # passed thru time
1821      new_y = gy * inputs + (1.0 - gy) * y  # passed thru depth
1822
1823    return new_y, new_state
1824
1825
1826_REGISTERED_OPS = None
1827
1828
1829class CompiledWrapper(rnn_cell_impl.RNNCell):
1830  """Wraps step execution in an XLA JIT scope."""
1831
1832  def __init__(self, cell, compile_stateful=False):
1833    """Create CompiledWrapper cell.
1834
1835    Args:
1836      cell: Instance of `RNNCell`.
1837      compile_stateful: Whether to compile stateful ops like initializers
1838        and random number generators (default: False).
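
    Example (a minimal usage sketch; assumes an XLA-enabled build and that the
    wrapper is exported as `tf.contrib.rnn.CompiledWrapper`):

      import tensorflow as tf

      cell = tf.contrib.rnn.CompiledWrapper(tf.contrib.rnn.GRUCell(64))
      inputs = tf.placeholder(tf.float32, [8, 10, 32])  # [batch, time, depth]
      outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)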
1839    """
1840    self._cell = cell
1841    self._compile_stateful = compile_stateful
1842
1843  @property
1844  def state_size(self):
1845    return self._cell.state_size
1846
1847  @property
1848  def output_size(self):
1849    return self._cell.output_size
1850
1851  def zero_state(self, batch_size, dtype):
1852    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
1853      return self._cell.zero_state(batch_size, dtype)
1854
1855  def __call__(self, inputs, state, scope=None):
1856    if self._compile_stateful:
1857      compile_ops = True
1858    else:
1859
1860      def compile_ops(node_def):
1861        global _REGISTERED_OPS
1862        if _REGISTERED_OPS is None:
1863          _REGISTERED_OPS = op_def_registry.get_registered_ops()
1864        return not _REGISTERED_OPS[node_def.op].is_stateful
1865
1866    with jit.experimental_jit_scope(compile_ops=compile_ops):
1867      return self._cell(inputs, state, scope=scope)
1868
1869
1870def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
1871  """Returns an exponential distribution initializer.
1872
1873  Args:
1874    minval: float or a scalar float Tensor. With value > 0. Lower bound of the
1875        range of random values to generate.
1876    maxval: float or a scalar float Tensor. With value > minval. Upper bound of
1877        the range of random values to generate.
1878    seed: An integer. Used to create random seeds.
1879    dtype: The data type.
1880
1881  Returns:
1882    An initializer that generates tensors with an exponential distribution.
1883  """
1884
1885  def _initializer(shape, dtype=dtype, partition_info=None):
1886    del partition_info  # Unused.
1887    return math_ops.exp(
1888        random_ops.random_uniform(
1889            shape, math_ops.log(minval), math_ops.log(maxval), dtype,
1890            seed=seed))
1891
1892  return _initializer
1893
1894
1895class PhasedLSTMCell(rnn_cell_impl.RNNCell):
1896  """Phased LSTM recurrent network cell.
1897
1898  https://arxiv.org/pdf/1610.09513v1.pdf
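
  Example (a minimal usage sketch; assumes the cell is exported as
  `tf.contrib.rnn.PhasedLSTMCell`; inputs are a (time, features) tuple and the
  shapes are illustrative only):

    import tensorflow as tf

    cell = tf.contrib.rnn.PhasedLSTMCell(num_units=32)
    times = tf.placeholder(tf.float32, [8, 100, 1])      # event timestamps
    features = tf.placeholder(tf.float32, [8, 100, 16])  # event features
    outputs, state = tf.nn.dynamic_rnn(
        cell, (times, features), dtype=tf.float32)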
1899  """
1900
1901  def __init__(self,
1902               num_units,
1903               use_peepholes=False,
1904               leak=0.001,
1905               ratio_on=0.1,
1906               trainable_ratio_on=True,
1907               period_init_min=1.0,
1908               period_init_max=1000.0,
1909               reuse=None):
1910    """Initialize the Phased LSTM cell.
1911
1912    Args:
1913      num_units: int, The number of units in the Phased LSTM cell.
1914      use_peepholes: bool, set True to enable peephole connections.
1915      leak: float or scalar float Tensor with value in [0, 1]. Leak applied
1916          during training.
1917      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
1918          period during which the gates are open.
      trainable_ratio_on: bool, whether ratio_on is trainable.
1920      period_init_min: float or scalar float Tensor. With value > 0.
1921          Minimum value of the initialized period.
1922          The period values are initialized by drawing from the distribution:
1923          e^U(log(period_init_min), log(period_init_max))
1924          Where U(.,.) is the uniform distribution.
1925      period_init_max: float or scalar float Tensor.
1926          With value > period_init_min. Maximum value of the initialized period.
1927      reuse: (optional) Python boolean describing whether to reuse variables
1928        in an existing scope. If not `True`, and the existing scope already has
1929        the given variables, an error is raised.
1930    """
1931    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
1932    self._num_units = num_units
1933    self._use_peepholes = use_peepholes
1934    self._leak = leak
1935    self._ratio_on = ratio_on
1936    self._trainable_ratio_on = trainable_ratio_on
1937    self._period_init_min = period_init_min
1938    self._period_init_max = period_init_max
1939    self._reuse = reuse
1940    self._linear1 = None
1941    self._linear2 = None
1942    self._linear3 = None
1943
1944  @property
1945  def state_size(self):
1946    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
1947
1948  @property
1949  def output_size(self):
1950    return self._num_units
1951
1952  def _mod(self, x, y):
1953    """Modulo function that propagates x gradients."""
1954    return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x
1955
1956  def _get_cycle_ratio(self, time, phase, period):
1957    """Compute the cycle ratio in the dtype of the time."""
1958    phase_casted = math_ops.cast(phase, dtype=time.dtype)
1959    period_casted = math_ops.cast(period, dtype=time.dtype)
1960    shifted_time = time - phase_casted
1961    cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
1962    return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
1963
1964  def call(self, inputs, state):
1965    """Phased LSTM Cell.
1966
1967    Args:
      inputs: A tuple of 2 Tensors.
1969         The first Tensor has shape [batch, 1], and type float32 or float64.
1970         It stores the time.
1971         The second Tensor has shape [batch, features_size], and type float32.
1972         It stores the features.
1973      state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.
1974
1975    Returns:
1976      A tuple containing:
1977      - A Tensor of float32, and shape [batch_size, num_units], representing the
1978        output of the cell.
1979      - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
1980        [batch_size, num_units], representing the new state and the output.
1981    """
1982    (c_prev, h_prev) = state
1983    (time, x) = inputs
1984
1985    in_mask_gates = [x, h_prev]
1986    if self._use_peepholes:
1987      in_mask_gates.append(c_prev)
1988
1989    with vs.variable_scope("mask_gates"):
1990      if self._linear1 is None:
1991        self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)
1992
1993      mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
1994      [input_gate, forget_gate] = array_ops.split(
1995          axis=1, num_or_size_splits=2, value=mask_gates)
1996
1997    with vs.variable_scope("new_input"):
1998      if self._linear2 is None:
1999        self._linear2 = _Linear([x, h_prev], self._num_units, True)
2000      new_input = math_ops.tanh(self._linear2([x, h_prev]))
2001
2002    new_c = (c_prev * forget_gate + input_gate * new_input)
2003
2004    in_out_gate = [x, h_prev]
2005    if self._use_peepholes:
2006      in_out_gate.append(new_c)
2007
2008    with vs.variable_scope("output_gate"):
2009      if self._linear3 is None:
2010        self._linear3 = _Linear(in_out_gate, self._num_units, True)
2011      output_gate = math_ops.sigmoid(self._linear3(in_out_gate))
2012
2013    new_h = math_ops.tanh(new_c) * output_gate
2014
2015    period = vs.get_variable(
2016        "period", [self._num_units],
2017        initializer=_random_exp_initializer(self._period_init_min,
2018                                            self._period_init_max))
2019    phase = vs.get_variable(
2020        "phase", [self._num_units],
2021        initializer=init_ops.random_uniform_initializer(0.,
2022                                                        period.initial_value))
2023    ratio_on = vs.get_variable(
2024        "ratio_on", [self._num_units],
2025        initializer=init_ops.constant_initializer(self._ratio_on),
2026        trainable=self._trainable_ratio_on)
2027
2028    cycle_ratio = self._get_cycle_ratio(time, phase, period)
2029
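    # Piecewise-linear time gate: a rising phase, a falling phase, and a leaky
    # "closed" phase outside the open fraction of the period.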
2030    k_up = 2 * cycle_ratio / ratio_on
2031    k_down = 2 - k_up
2032    k_closed = self._leak * cycle_ratio
2033
2034    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
2035    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
2036
2037    new_c = k * new_c + (1 - k) * c_prev
2038    new_h = k * new_h + (1 - k) * h_prev
2039
2040    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
2041
2042    return new_h, new_state
2043
2044
2045class ConvLSTMCell(rnn_cell_impl.RNNCell):
2046  """Convolutional LSTM recurrent network cell.
2047
2048  https://arxiv.org/pdf/1506.04214v1.pdf
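
  Example (a minimal usage sketch for the 2-D variant; assumes it is exported
  as `tf.contrib.rnn.Conv2DLSTMCell`; shapes are illustrative only):

    import tensorflow as tf

    cell = tf.contrib.rnn.Conv2DLSTMCell(
        input_shape=[16, 16, 3], output_channels=8, kernel_shape=[3, 3])
    # inputs: [batch, time, height, width, channels].
    inputs = tf.placeholder(tf.float32, [4, 10, 16, 16, 3])
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)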
2049  """
2050
2051  def __init__(self,
2052               conv_ndims,
2053               input_shape,
2054               output_channels,
2055               kernel_shape,
2056               use_bias=True,
2057               skip_connection=False,
2058               forget_bias=1.0,
2059               initializers=None,
2060               name="conv_lstm_cell"):
2061    """Construct ConvLSTMCell.
2062    Args:
2063      conv_ndims: Convolution dimensionality (1, 2 or 3).
2064      input_shape: Shape of the input as int tuple, excluding the batch size.
2065      output_channels: int, number of output channels of the conv LSTM.
2066      kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3).
2067      use_bias: Use bias in convolutions.
2068      skip_connection: If set to `True`, concatenate the input to the
2069      output of the conv LSTM. Default: `False`.
2070      forget_bias: Forget bias.
2071      name: Name of the module.
2072    Raises:
2073      ValueError: If `skip_connection` is `True` and stride is different from 1
2074        or if `input_shape` is incompatible with `conv_ndims`.
2075    """
2076    super(ConvLSTMCell, self).__init__(name=name)
2077
2078    if conv_ndims != len(input_shape) - 1:
2079      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
2080          input_shape, conv_ndims))
2081
2082    self._conv_ndims = conv_ndims
2083    self._input_shape = input_shape
2084    self._output_channels = output_channels
2085    self._kernel_shape = kernel_shape
2086    self._use_bias = use_bias
2087    self._forget_bias = forget_bias
2088    self._skip_connection = skip_connection
2089
2090    self._total_output_channels = output_channels
2091    if self._skip_connection:
2092      self._total_output_channels += self._input_shape[-1]
2093
2094    state_size = tensor_shape.TensorShape(
2095        self._input_shape[:-1] + [self._output_channels])
2096    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
2097    self._output_size = tensor_shape.TensorShape(
2098        self._input_shape[:-1] + [self._total_output_channels])
2099
2100  @property
2101  def output_size(self):
2102    return self._output_size
2103
2104  @property
2105  def state_size(self):
2106    return self._state_size
2107
2108  def call(self, inputs, state, scope=None):
2109    cell, hidden = state
2110    new_hidden = _conv([inputs, hidden], self._kernel_shape,
2111                       4 * self._output_channels, self._use_bias)
2112    gates = array_ops.split(
2113        value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)
2114
2115    input_gate, new_input, forget_gate, output_gate = gates
2116    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
2117    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
2118    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)
2119
2120    if self._skip_connection:
2121      output = array_ops.concat([output, inputs], axis=-1)
2122    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
2123    return output, new_state
2124
2125
2126class Conv1DLSTMCell(ConvLSTMCell):
2127  """1D Convolutional LSTM recurrent network cell.
2128
2129  https://arxiv.org/pdf/1506.04214v1.pdf
2130  """
2131
2132  def __init__(self, name="conv_1d_lstm_cell", **kwargs):
2133    """Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
2134    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs)
2135
2136
2137class Conv2DLSTMCell(ConvLSTMCell):
2138  """2D Convolutional LSTM recurrent network cell.
2139
2140  https://arxiv.org/pdf/1506.04214v1.pdf
2141  """
2142
2143  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
2144    """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
2145    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs)
2146
2147
2148class Conv3DLSTMCell(ConvLSTMCell):
2149  """3D Convolutional LSTM recurrent network cell.
2150
2151  https://arxiv.org/pdf/1506.04214v1.pdf
2152  """
2153
2154  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
2155    """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
2156    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs)
2157
2158
2159def _conv(args, filter_size, num_features, bias, bias_start=0.0):
2160  """convolution:
2161  Args:
2162    args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D,
2163    batch x n, Tensors.
2164    filter_size: int tuple of filter height and width.
2165    num_features: int, number of features.
2166    bias_start: starting value to initialize the bias; 0 by default.
2167  Returns:
2168    A 3D, 4D, or 5D Tensor with shape [batch ... num_features]
2169  Raises:
2170    ValueError: if some of the arguments has unspecified or wrong shape.
2171  """
2172
2173  # Calculate the total size of arguments on dimension 1.
2174  total_arg_size_depth = 0
2175  shapes = [a.get_shape().as_list() for a in args]
2176  shape_length = len(shapes[0])
2177  for shape in shapes:
2178    if len(shape) not in [3, 4, 5]:
2179      raise ValueError("Conv Linear expects 3D, 4D "
2180                       "or 5D arguments: %s" % str(shapes))
2181    if len(shape) != len(shapes[0]):
2182      raise ValueError("Conv Linear expects all args "
2183                       "to be of same Dimension: %s" % str(shapes))
2184    else:
2185      total_arg_size_depth += shape[-1]
2186  dtype = [a.dtype for a in args][0]
2187
2188  # determine correct conv operation
2189  if shape_length == 3:
2190    conv_op = nn_ops.conv1d
2191    strides = 1
2192  elif shape_length == 4:
2193    conv_op = nn_ops.conv2d
2194    strides = shape_length * [1]
2195  elif shape_length == 5:
2196    conv_op = nn_ops.conv3d
2197    strides = shape_length * [1]
2198
2199  # Now the computation.
2200  kernel = vs.get_variable(
2201      "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
2202  if len(args) == 1:
2203    res = conv_op(args[0], kernel, strides, padding="SAME")
2204  else:
2205    res = conv_op(
2206        array_ops.concat(axis=shape_length - 1, values=args),
2207        kernel,
2208        strides,
2209        padding="SAME")
2210  if not bias:
2211    return res
2212  bias_term = vs.get_variable(
2213      "biases", [num_features],
2214      dtype=dtype,
2215      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
2216  return res + bias_term
2217
2218
2219class GLSTMCell(rnn_cell_impl.RNNCell):
2220  """Group LSTM cell (G-LSTM).
2221
2222  The implementation is based on:
2223
2224    https://arxiv.org/abs/1703.10722
2225
2226  O. Kuchaiev and B. Ginsburg
2227  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
2228  """
2229
2230  def __init__(self,
2231               num_units,
2232               initializer=None,
2233               num_proj=None,
2234               number_of_groups=1,
2235               forget_bias=1.0,
2236               activation=math_ops.tanh,
2237               reuse=None):
2238    """Initialize the parameters of G-LSTM cell.
2239
2240    Args:
2241      num_units: int, The number of units in the G-LSTM cell
2242      initializer: (optional) The initializer to use for the weight and
2243        projection matrices.
2244      num_proj: (optional) int, The output dimensionality for the projection
2245        matrices.  If None, no projection is performed.
2246      number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, the cell is equivalent to a standard
        LSTM cell.
2248      forget_bias: Biases of the forget gate are initialized by default to 1
2249        in order to reduce the scale of forgetting at the beginning of
2250        the training.
2251      activation: Activation function of the inner states.
2252      reuse: (optional) Python boolean describing whether to reuse variables
2253        in an existing scope.  If not `True`, and the existing scope already
2254        has the given variables, an error is raised.
2255
2256    Raises:
2257      ValueError: If `num_units` or `num_proj` is not divisible by
2258        `number_of_groups`.
2259    """
2260    super(GLSTMCell, self).__init__(_reuse=reuse)
2261    self._num_units = num_units
2262    self._initializer = initializer
2263    self._num_proj = num_proj
2264    self._forget_bias = forget_bias
2265    self._activation = activation
2266    self._number_of_groups = number_of_groups
2267
2268    if self._num_units % self._number_of_groups != 0:
2269      raise ValueError("num_units must be divisible by number_of_groups")
2270    if self._num_proj:
2271      if self._num_proj % self._number_of_groups != 0:
2272        raise ValueError("num_proj must be divisible by number_of_groups")
2273      self._group_shape = [
2274          int(self._num_proj / self._number_of_groups),
2275          int(self._num_units / self._number_of_groups)
2276      ]
2277    else:
2278      self._group_shape = [
2279          int(self._num_units / self._number_of_groups),
2280          int(self._num_units / self._number_of_groups)
2281      ]
2282
2283    if num_proj:
2284      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
2285      self._output_size = num_proj
2286    else:
2287      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
2288      self._output_size = num_units
2289    self._linear1 = [None] * number_of_groups
2290    self._linear2 = None
2291
2292  @property
2293  def state_size(self):
2294    return self._state_size
2295
2296  @property
2297  def output_size(self):
2298    return self._output_size
2299
2300  def _get_input_for_group(self, inputs, group_id, group_size):
2301    """Slices inputs into groups to prepare for processing by cell's groups
2302
2303    Args:
2304      inputs: cell input or it's previous state,
2305              a Tensor, 2D, [batch x num_units]
2306      group_id: group id, a Scalar, for which to prepare input
2307      group_size: size of the group
2308
2309    Returns:
2310      subset of inputs corresponding to group "group_id",
2311      a Tensor, 2D, [batch x num_units/number_of_groups]
2312    """
2313    return array_ops.slice(
2314        input_=inputs,
2315        begin=[0, group_id * group_size],
2316        size=[self._batch_size, group_size],
2317        name=("GLSTM_group%d_input_generation" % group_id))
2318
2319  def call(self, inputs, state):
2320    """Run one step of G-LSTM.
2321
2322    Args:
2323      inputs: input Tensor, 2D, [batch x num_units].
2324      state: this must be a tuple of state Tensors, both `2-D`,
2325      with column sizes `c_state` and `m_state`.
2326
2327    Returns:
2328      A tuple containing:
2329
2330      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
2331        G-LSTM after reading `inputs` when previous state was `state`.
2332        Here output_dim is:
2333           num_proj if num_proj was set,
2334           num_units otherwise.
2335      - LSTMStateTuple representing the new state of G-LSTM cell
2336        after reading `inputs` when the previous state was `state`.
2337
2338    Raises:
2339      ValueError: If input size cannot be inferred from inputs via
2340        static shape inference.
2341    """
2342    (c_prev, m_prev) = state
2343
2344    self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
2345    dtype = inputs.dtype
2346    scope = vs.get_variable_scope()
2347    with vs.variable_scope(scope, initializer=self._initializer):
2348      i_parts = []
2349      j_parts = []
2350      f_parts = []
2351      o_parts = []
2352
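      # Compute the i, j, f, o pre-activations group by group, then
      # concatenate the per-group results along the feature axis below.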
2353      for group_id in range(self._number_of_groups):
2354        with vs.variable_scope("group%d" % group_id):
2355          x_g_id = array_ops.concat(
2356              [
2357                  self._get_input_for_group(inputs, group_id,
2358                                            self._group_shape[0]),
2359                  self._get_input_for_group(m_prev, group_id,
2360                                            self._group_shape[0])
2361              ],
2362              axis=1)
2363          linear = self._linear1[group_id]
2364          if linear is None:
2365            linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
2366            self._linear1[group_id] = linear
2367          R_k = linear(x_g_id)  # pylint: disable=invalid-name
2368          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)
2369
2370        i_parts.append(i_k)
2371        j_parts.append(j_k)
2372        f_parts.append(f_k)
2373        o_parts.append(o_k)
2374
2375      bi = vs.get_variable(
2376          name="bias_i",
2377          shape=[self._num_units],
2378          dtype=dtype,
2379          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2380      bj = vs.get_variable(
2381          name="bias_j",
2382          shape=[self._num_units],
2383          dtype=dtype,
2384          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2385      bf = vs.get_variable(
2386          name="bias_f",
2387          shape=[self._num_units],
2388          dtype=dtype,
2389          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2390      bo = vs.get_variable(
2391          name="bias_o",
2392          shape=[self._num_units],
2393          dtype=dtype,
2394          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2395
2396      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
2397      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
2398      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
2399      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)
2400
2401    c = (
2402        math_ops.sigmoid(f + self._forget_bias) * c_prev +
2403        math_ops.sigmoid(i) * math_ops.tanh(j))
2404    m = math_ops.sigmoid(o) * self._activation(c)
2405
2406    if self._num_proj is not None:
2407      with vs.variable_scope("projection"):
2408        if self._linear2 is None:
2409          self._linear2 = _Linear(m, self._num_proj, False)
2410        m = self._linear2(m)
2411
2412    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
2413    return m, new_state
2414
2415
2416class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
2417  """Long short-term memory unit (LSTM) recurrent network cell.
2418
2419  The default non-peephole implementation is based on:
2420
2421    http://www.bioinf.jku.at/publications/older/2604.pdf
2422
2423  S. Hochreiter and J. Schmidhuber.
2424  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
2425
2426  The peephole implementation is based on:
2427
2428    https://research.google.com/pubs/archive/43905.pdf
2429
2430  Hasim Sak, Andrew Senior, and Francoise Beaufays.
2431  "Long short-term memory recurrent neural network architectures for
2432   large scale acoustic modeling." INTERSPEECH, 2014.
2433
2434  The class uses optional peep-hole connections, optional cell clipping, and
2435  an optional projection layer.
2436
2437  Layer normalization implementation is based on:
2438
2439    https://arxiv.org/abs/1607.06450.
2440
2441  "Layer Normalization"
2442  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
2443
2444  and is applied before the internal nonlinearities.
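
  Example (a minimal usage sketch; assumes the cell is exported as
  `tf.contrib.rnn.LayerNormLSTMCell`; shapes are illustrative only):

    import tensorflow as tf

    cell = tf.contrib.rnn.LayerNormLSTMCell(
        num_units=64, layer_norm=True, use_peepholes=True)
    inputs = tf.placeholder(tf.float32, [16, 20, 32])  # [batch, time, depth]
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)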
2445
2446  """
2447
2448  def __init__(self,
2449               num_units,
2450               use_peepholes=False,
2451               cell_clip=None,
2452               initializer=None,
2453               num_proj=None,
2454               proj_clip=None,
2455               forget_bias=1.0,
2456               activation=None,
2457               layer_norm=False,
2458               norm_gain=1.0,
2459               norm_shift=0.0,
2460               reuse=None):
2461    """Initialize the parameters for an LSTM cell.
2462
2463    Args:
2464      num_units: int, The number of units in the LSTM cell
2465      use_peepholes: bool, set True to enable diagonal/peephole connections.
2466      cell_clip: (optional) A float value, if provided the cell state is clipped
2467        by this value prior to the cell output activation.
2468      initializer: (optional) The initializer to use for the weight and
2469        projection matrices.
2470      num_proj: (optional) int, The output dimensionality for the projection
2471        matrices.  If None, no projection is performed.
2472      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
2473        provided, then the projected values are clipped elementwise to within
2474        `[-proj_clip, proj_clip]`.
2475      forget_bias: Biases of the forget gate are initialized by default to 1
2476        in order to reduce the scale of forgetting at the beginning of
2477        the training. Must set it manually to `0.0` when restoring from
2478        CudnnLSTM trained checkpoints.
2479      activation: Activation function of the inner states.  Default: `tanh`.
2480      layer_norm: If `True`, layer normalization will be applied.
2481      norm_gain: float, The layer normalization gain initial value. If
2482        `layer_norm` has been set to `False`, this argument will be ignored.
2483      norm_shift: float, The layer normalization shift initial value. If
2484        `layer_norm` has been set to `False`, this argument will be ignored.
2485      reuse: (optional) Python boolean describing whether to reuse variables
2486        in an existing scope.  If not `True`, and the existing scope already has
2487        the given variables, an error is raised.
2488
2489      When restoring from CudnnLSTM-trained checkpoints, must use
2490      CudnnCompatibleLSTMCell instead.
2491    """
2492    super(LayerNormLSTMCell, self).__init__(_reuse=reuse)
2493
2494    self._num_units = num_units
2495    self._use_peepholes = use_peepholes
2496    self._cell_clip = cell_clip
2497    self._initializer = initializer
2498    self._num_proj = num_proj
2499    self._proj_clip = proj_clip
2500    self._forget_bias = forget_bias
2501    self._activation = activation or math_ops.tanh
2502    self._layer_norm = layer_norm
2503    self._norm_gain = norm_gain
2504    self._norm_shift = norm_shift
2505
2506    if num_proj:
2507      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj))
2508      self._output_size = num_proj
2509    else:
2510      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units))
2511      self._output_size = num_units
2512
2513  @property
2514  def state_size(self):
2515    return self._state_size
2516
2517  @property
2518  def output_size(self):
2519    return self._output_size
2520
2521  def _linear(self,
2522              args,
2523              output_size,
2524              bias,
2525              bias_initializer=None,
2526              kernel_initializer=None,
2527              layer_norm=False):
2528    """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable.
2529
2530    Args:
2531      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
2532      output_size: int, second dimension of W[i].
2533      bias: boolean, whether to add a bias term or not.
2534      bias_initializer: starting value to initialize the bias
2535        (default is all zeros).
2536      kernel_initializer: starting value to initialize the weight.
      layer_norm: boolean, whether to apply layer normalization.

    Returns:
2541      A 2D Tensor with shape [batch x output_size] taking value
2542      sum_i(args[i] * W[i]), where each W[i] is a newly created Variable.
2543
2544    Raises:
2545      ValueError: if some of the arguments has unspecified or wrong shape.
2546    """
2547    if args is None or (nest.is_sequence(args) and not args):
2548      raise ValueError("`args` must be specified")
2549    if not nest.is_sequence(args):
2550      args = [args]
2551
2552    # Calculate the total size of arguments on dimension 1.
2553    total_arg_size = 0
2554    shapes = [a.get_shape() for a in args]
2555    for shape in shapes:
2556      if shape.ndims != 2:
2557        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
2558      if shape[1].value is None:
2559        raise ValueError("linear expects shape[1] to be provided for shape %s, "
2560                         "but saw %s" % (shape, shape[1]))
2561      else:
2562        total_arg_size += shape[1].value
2563
2564    dtype = [a.dtype for a in args][0]
2565
2566    # Now the computation.
2567    scope = vs.get_variable_scope()
2568    with vs.variable_scope(scope) as outer_scope:
2569      weights = vs.get_variable(
2570          "kernel", [total_arg_size, output_size],
2571          dtype=dtype,
2572          initializer=kernel_initializer)
2573      if len(args) == 1:
2574        res = math_ops.matmul(args[0], weights)
2575      else:
2576        res = math_ops.matmul(array_ops.concat(args, 1), weights)
2577      if not bias:
2578        return res
2579      with vs.variable_scope(outer_scope) as inner_scope:
2580        inner_scope.set_partitioner(None)
2581        if bias_initializer is None:
2582          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
2583        biases = vs.get_variable(
2584            "bias", [output_size], dtype=dtype, initializer=bias_initializer)
2585
2586    if not layer_norm:
2587      res = nn_ops.bias_add(res, biases)
2588
2589    return res
2590
2591  def call(self, inputs, state):
2592    """Run one step of LSTM.
2593
2594    Args:
2595      inputs: input Tensor, 2D, batch x num_units.
2596      state: this must be a tuple of state Tensors,
2597       both `2-D`, with column sizes `c_state` and
2598        `m_state`.
2599
2600    Returns:
2601      A tuple containing:
2602
2603      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
2604        LSTM after reading `inputs` when previous state was `state`.
2605        Here output_dim is:
2606           num_proj if num_proj was set,
2607           num_units otherwise.
2608      - Tensor(s) representing the new state of LSTM after reading `inputs` when
2609        the previous state was `state`.  Same type and shape(s) as `state`.
2610
2611    Raises:
2612      ValueError: If input size cannot be inferred from inputs via
2613        static shape inference.
2614    """
2615    sigmoid = math_ops.sigmoid
2616
2617    (c_prev, m_prev) = state
2618
2619    dtype = inputs.dtype
2620    input_size = inputs.get_shape().with_rank(2)[1]
2621    if input_size.value is None:
2622      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
2623    scope = vs.get_variable_scope()
2624    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
2625
2626      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
2627      lstm_matrix = self._linear(
2628          [inputs, m_prev],
2629          4 * self._num_units,
2630          bias=True,
2631          bias_initializer=None,
2632          layer_norm=self._layer_norm)
2633      i, j, f, o = array_ops.split(
2634          value=lstm_matrix, num_or_size_splits=4, axis=1)
2635
2636      if self._layer_norm:
2637        i = _norm(self._norm_gain, self._norm_shift, i, "input")
2638        j = _norm(self._norm_gain, self._norm_shift, j, "transform")
2639        f = _norm(self._norm_gain, self._norm_shift, f, "forget")
2640        o = _norm(self._norm_gain, self._norm_shift, o, "output")
2641
2642      # Diagonal connections
2643      if self._use_peepholes:
2644        with vs.variable_scope(unit_scope):
2645          w_f_diag = vs.get_variable(
2646              "w_f_diag", shape=[self._num_units], dtype=dtype)
2647          w_i_diag = vs.get_variable(
2648              "w_i_diag", shape=[self._num_units], dtype=dtype)
2649          w_o_diag = vs.get_variable(
2650              "w_o_diag", shape=[self._num_units], dtype=dtype)
2651
2652      if self._use_peepholes:
2653        c = (
2654            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
2655            sigmoid(i + w_i_diag * c_prev) * self._activation(j))
2656      else:
2657        c = (
2658            sigmoid(f + self._forget_bias) * c_prev +
2659            sigmoid(i) * self._activation(j))
2660
2661      if self._layer_norm:
2662        c = _norm(self._norm_gain, self._norm_shift, c, "state")
2663
2664      if self._cell_clip is not None:
2665        # pylint: disable=invalid-unary-operand-type
2666        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
2667        # pylint: enable=invalid-unary-operand-type
2668      if self._use_peepholes:
2669        m = sigmoid(o + w_o_diag * c) * self._activation(c)
2670      else:
2671        m = sigmoid(o) * self._activation(c)
2672
2673      if self._num_proj is not None:
2674        with vs.variable_scope("projection"):
2675          m = self._linear(m, self._num_proj, bias=False)
2676
2677        if self._proj_clip is not None:
2678          # pylint: disable=invalid-unary-operand-type
2679          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
2680          # pylint: enable=invalid-unary-operand-type
2681
2682    new_state = (rnn_cell_impl.LSTMStateTuple(c, m))
2683    return m, new_state
2684
2685
2686class SRUCell(rnn_cell_impl.LayerRNNCell):
2687  """SRU, Simple Recurrent Unit
2688
2689     Implementation based on
2690     Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).
2691
     This variation of RNN cell is characterized by the simplified data
     dependence between the hidden states of two consecutive time steps.
     Traditionally, the hidden state from the cell at time step t-1 needs to
     be multiplied with a matrix W_hh before being fed into the ensuing cell
     at time step t. This flavor of RNN replaces that matrix multiplication
     between h_{t-1} and W_hh with a pointwise multiplication, resulting in a
     performance gain.
2700
2701  Args:
2702    num_units: int, The number of units in the SRU cell.
2703    activation: Nonlinearity to use.  Default: `tanh`.
2704    reuse: (optional) Python boolean describing whether to reuse variables
2705      in an existing scope.  If not `True`, and the existing scope already has
2706      the given variables, an error is raised.
2707    name: (optional) String, the name of the layer. Layers with the same name
2708      will share weights, but to avoid mistakes we require reuse=True in such
2709      cases.
2710  """
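  # A minimal usage sketch (illustrative only, not part of the original
  # docstring); `inputs` below is assumed to be a [batch, time, depth]
  # float Tensor:
  #
  #   cell = SRUCell(num_units=128)
  #   outputs, final_c = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)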
2711
2712  def __init__(self, num_units, activation=None, reuse=None, name=None):
2713    super(SRUCell, self).__init__(_reuse=reuse, name=name)
2714    self._num_units = num_units
2715    self._activation = activation or math_ops.tanh
2716
2717    # Restrict inputs to be 2-dimensional matrices
2718    self.input_spec = base_layer.InputSpec(ndim=2)
2719
2720  @property
2721  def state_size(self):
2722    return self._num_units
2723
2724  @property
2725  def output_size(self):
2726    return self._num_units
2727
2728  def build(self, inputs_shape):
2729    if inputs_shape[1].value is None:
2730      raise ValueError(
2731          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
2732
2733    input_depth = inputs_shape[1].value
2734
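    # A single kernel packs four blocks of num_units columns each: the
    # candidate x_bar, the forget- and reset-gate pre-activations, and the
    # highway bypass x_tx (see `call` below).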
2735    self._kernel = self.add_variable(
2736        rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
2737        shape=[input_depth, 4 * self._num_units])
2738
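    # Only the forget and reset gates carry a bias, hence 2 * num_units.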
2739    self._bias = self.add_variable(
2740        rnn_cell_impl._BIAS_VARIABLE_NAME,
2741        shape=[2 * self._num_units],
2742        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))
2743
2744    self.built = True
2745
2746  def call(self, inputs, state):
2747    """Simple recurrent unit (SRU) with num_units cells."""
2748
2749    U = math_ops.matmul(inputs, self._kernel)
2750    x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split(
2751        value=U, num_or_size_splits=4, axis=1)
2752
2753    f_r = math_ops.sigmoid(
2754        nn_ops.bias_add(
2755            array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
2756    f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
2757
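    # SRU recurrences:
    #   c_t = f_t * c_{t-1} + (1 - f_t) * x_bar_t
    #   h_t = r_t * activation(c_t) + (1 - r_t) * x_tx_t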
2758    c = f * state + (1.0 - f) * x_bar
2759    h = r * self._activation(c) + (1.0 - r) * x_tx
2760
2761    return h, c
2762
2763
2764class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
2765  """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.
2766
2767    The weight-norm implementation is based on:
2768    https://arxiv.org/abs/1602.07868
2769    Tim Salimans, Diederik P. Kingma.
2770    Weight Normalization: A Simple Reparameterization to Accelerate
2771    Training of Deep Neural Networks
2772
2773    The default LSTM implementation is based on:
2774    http://www.bioinf.jku.at/publications/older/2604.pdf
2775    S. Hochreiter and J. Schmidhuber.
2776    "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
2777
2778    The class uses optional peephole connections, optional cell clipping
2779    and an optional projection layer.
2780
2781    The optional peephole implementation is based on:
2782    https://research.google.com/pubs/archive/43905.pdf
2783    Hasim Sak, Andrew Senior, and Francoise Beaufays.
2784    "Long short-term memory recurrent neural network architectures for
2785    large scale acoustic modeling." INTERSPEECH, 2014.
2786  """
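  # A minimal usage sketch (illustrative only, not part of the original
  # docstring):
  #
  #   cell = WeightNormLSTMCell(num_units=128, norm=True, use_peepholes=True)
  #   outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)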
2787
2788  def __init__(self,
2789               num_units,
2790               norm=True,
2791               use_peepholes=False,
2792               cell_clip=None,
2793               initializer=None,
2794               num_proj=None,
2795               proj_clip=None,
2796               forget_bias=1,
2797               activation=None,
2798               reuse=None):
2799    """Initialize the parameters of a weight-normalized LSTM cell.
2800
2801    Args:
2802      num_units: int, The number of units in the LSTM cell
2803      norm: If `True`, apply normalization to the weight matrices. If `False`,
2804        the result is identical to that obtained from `rnn_cell_impl.LSTMCell`.
2805      use_peepholes: bool, set `True` to enable diagonal/peephole connections.
2806      cell_clip: (optional) A float value. If provided, the cell state is clipped
2807        to `[-cell_clip, cell_clip]` prior to the cell output activation.
2808      initializer: (optional) The initializer to use for the weight matrices.
2809      num_proj: (optional) int, The output dimensionality for the projection
2810        matrices.  If None, no projection is performed.
2811      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
2812        provided, then the projected values are clipped elementwise to within
2813        `[-proj_clip, proj_clip]`.
2814      forget_bias: float, added to the forget gate pre-activation (default 1)
2815        in order to reduce the scale of forgetting at the beginning of
2816        training.
2817      activation: Activation function of the inner states.  Default: `tanh`.
2818      reuse: (optional) Python boolean describing whether to reuse variables
2819        in an existing scope.  If not `True`, and the existing scope already has
2820        the given variables, an error is raised.
2821    """
2822    super(WeightNormLSTMCell, self).__init__(_reuse=reuse)
2823
2824    self._scope = "wn_lstm_cell"
2825    self._num_units = num_units
2826    self._norm = norm
2827    self._initializer = initializer
2828    self._use_peepholes = use_peepholes
2829    self._cell_clip = cell_clip
2830    self._num_proj = num_proj
2831    self._proj_clip = proj_clip
2832    self._activation = activation or math_ops.tanh
2833    self._forget_bias = forget_bias
2834
2835    self._weights_variable_name = "kernel"
2836    self._bias_variable_name = "bias"
2837
2838    if num_proj:
2839      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
2840      self._output_size = num_proj
2841    else:
2842      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
2843      self._output_size = num_units
2844
2845  @property
2846  def state_size(self):
2847    return self._state_size
2848
2849  @property
2850  def output_size(self):
2851    return self._output_size
2852
2853  def _normalize(self, weight, name):
2854    """Apply weight normalization.
2855
2856    Args:
2857      weight: a 2D tensor with known number of columns.
2858      name: string, variable name for the normalizer.
2859    Returns:
2860      A tensor with the same shape as `weight`.
2861    """
2862
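    # Weight normalization (Salimans & Kingma, 2016): w = g * v / ||v||_2,
    # applied independently to each output column of `weight`.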
2863    output_size = weight.get_shape().as_list()[1]
2864    g = vs.get_variable(name, [output_size], dtype=weight.dtype)
2865    return nn_impl.l2_normalize(weight, dim=0) * g
2866
2867  def _linear(self,
2868              args,
2869              output_size,
2870              norm,
2871              bias,
2872              bias_initializer=None,
2873              kernel_initializer=None):
2874    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
2875
2876    Args:
2877      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
2878      output_size: int, second dimension of W[i].
      norm: bool, whether to apply weight normalization to each W[i].
2879      bias: boolean, whether to add a bias term or not.
2880      bias_initializer: starting value to initialize the bias
2881        (default is all zeros).
2882      kernel_initializer: starting value to initialize the weight.
2883
2884    Returns:
2885      A 2D Tensor with shape [batch x output_size] equal to
2886      sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
2887
2888    Raises:
2889      ValueError: if any of the arguments has an unspecified or wrong shape.
2890    """
2891    if args is None or (nest.is_sequence(args) and not args):
2892      raise ValueError("`args` must be specified")
2893    if not nest.is_sequence(args):
2894      args = [args]
2895
2896    # Calculate the total size of arguments on dimension 1.
2897    total_arg_size = 0
2898    shapes = [a.get_shape() for a in args]
2899    for shape in shapes:
2900      if shape.ndims != 2:
2901        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
2902      if shape[1].value is None:
2903        raise ValueError("linear expects shape[1] to be provided for shape %s, "
2904                         "but saw %s" % (shape, shape[1]))
2905      else:
2906        total_arg_size += shape[1].value
2907
2908    dtype = [a.dtype for a in args][0]
2909
2910    # Now the computation.
2911    scope = vs.get_variable_scope()
2912    with vs.variable_scope(scope) as outer_scope:
2913      weights = vs.get_variable(
2914          self._weights_variable_name, [total_arg_size, output_size],
2915          dtype=dtype,
2916          initializer=kernel_initializer)
2917      if norm:
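        # Normalize each input argument's block of rows separately, giving
        # every argument its own gain vector (see `_normalize`).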
2918        wn = []
2919        st = 0
2920        with ops.control_dependencies(None):
2921          for i in range(len(args)):
2922            en = st + shapes[i][1].value
2923            wn.append(
2924                self._normalize(weights[st:en, :], name="norm_{}".format(i)))
2925            st = en
2926
2927          weights = array_ops.concat(wn, axis=0)
2928
2929      if len(args) == 1:
2930        res = math_ops.matmul(args[0], weights)
2931      else:
2932        res = math_ops.matmul(array_ops.concat(args, 1), weights)
2933      if not bias:
2934        return res
2935
2936      with vs.variable_scope(outer_scope) as inner_scope:
2937        inner_scope.set_partitioner(None)
2938        if bias_initializer is None:
2939          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
2940
2941        biases = vs.get_variable(
2942            self._bias_variable_name, [output_size],
2943            dtype=dtype,
2944            initializer=bias_initializer)
2945
2946      return nn_ops.bias_add(res, biases)
2947
2948  def call(self, inputs, state):
2949    """Run one step of LSTM.
2950
2951    Args:
2952      inputs: input Tensor, 2D, batch x input_size.
2953      state: A tuple of state Tensors, both `2-D`, with column sizes
2954       `c_state` and `m_state`.
2955
2956    Returns:
2957      A tuple containing:
2958
2959      - A `2-D, [batch x output_dim]` Tensor representing the output of the
2960        LSTM after reading `inputs` when previous state was `state`.
2961        Here output_dim is:
2962           num_proj if num_proj was set,
2963           num_units otherwise.
2964      - Tensor(s) representing the new state of LSTM after reading `inputs` when
2965        the previous state was `state`.  Same type and shape(s) as `state`.
2966
2967    Raises:
2968      ValueError: If input size cannot be inferred from inputs via
2969        static shape inference.
2970    """
2971    dtype = inputs.dtype
2972    num_units = self._num_units
2973    sigmoid = math_ops.sigmoid
2974    c, h = state
2975
2976    input_size = inputs.get_shape().with_rank(2)[1]
2977    if input_size.value is None:
2978      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
2979
2980    with vs.variable_scope(self._scope, initializer=self._initializer):
2981
2982      concat = self._linear(
2983          [inputs, h], 4 * num_units, norm=self._norm, bias=True)
2984
2985      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
2986      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
2987
2988      if self._use_peepholes:
2989        w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
2990        w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
2991        w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
2992
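        # Peephole variant: the forget and input gates see the previous cell
        # state c through the diagonal (per-unit) weights.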
2993        new_c = (
2994            c * sigmoid(f + self._forget_bias + w_f_diag * c) +
2995            sigmoid(i + w_i_diag * c) * self._activation(j))
2996      else:
2997        new_c = (
2998            c * sigmoid(f + self._forget_bias) +
2999            sigmoid(i) * self._activation(j))
3000
3001      if self._cell_clip is not None:
3002        # pylint: disable=invalid-unary-operand-type
3003        new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
3004        # pylint: enable=invalid-unary-operand-type
3005      if self._use_peepholes:
3006        new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
3007      else:
3008        new_h = sigmoid(o) * self._activation(new_c)
3009
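      # Optionally project the new hidden state down to num_proj dimensions
      # and clip the projection to [-proj_clip, proj_clip].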
3010      if self._num_proj is not None:
3011        with vs.variable_scope("projection"):
3012          new_h = self._linear(
3013              new_h, self._num_proj, norm=self._norm, bias=False)
3014
3015        if self._proj_clip is not None:
3016          # pylint: disable=invalid-unary-operand-type
3017          new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
3018                                         self._proj_clip)
3019          # pylint: enable=invalid-unary-operand-type
3020
3021      new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
3022      return new_h, new_state
3023