1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module for constructing RNN Cells."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import math
22
23from tensorflow.contrib.compiler import jit
24from tensorflow.contrib.layers.python.layers import layers
25from tensorflow.contrib.rnn.python.ops import core_rnn_cell
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import op_def_registry
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.keras import activations
32from tensorflow.python.keras import initializers
33from tensorflow.python.keras.engine import input_spec
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import clip_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import gen_array_ops
38from tensorflow.python.ops import init_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import nn_impl  # pylint: disable=unused-import
41from tensorflow.python.ops import nn_ops
42from tensorflow.python.ops import partitioned_variables  # pylint: disable=unused-import
43from tensorflow.python.ops import random_ops
44from tensorflow.python.ops import rnn_cell_impl
45from tensorflow.python.ops import variable_scope as vs
46from tensorflow.python.platform import tf_logging as logging
47from tensorflow.python.util import nest
48
49
50def _get_concat_variable(name, shape, dtype, num_shards):
51  """Get a sharded variable concatenated into one tensor."""
52  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
53  if len(sharded_variable) == 1:
54    return sharded_variable[0]
55
56  concat_name = name + "/concat"
57  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
58  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
59    if value.name == concat_full_name:
60      return value
61
62  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
63  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
64  return concat_variable
65
66
67def _get_sharded_variable(name, shape, dtype, num_shards):
68  """Get a list of sharded variables with the given dtype."""
69  if num_shards > shape[0]:
70    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
71                                                                   num_shards))
72  unit_shard_size = int(math.floor(shape[0] / num_shards))
73  remaining_rows = shape[0] - unit_shard_size * num_shards
74
75  shards = []
76  for i in range(num_shards):
77    current_size = unit_shard_size
78    if i < remaining_rows:
79      current_size += 1
80    shards.append(
81        vs.get_variable(
82            name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
83  return shards
84
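# Illustrative note (a sketch added for exposition, not part of the original
# module): for a weight named "W" of shape [10, 4] requested with num_shards=3,
# _get_sharded_variable creates "W_0", "W_1" and "W_2" with shapes [4, 4],
# [3, 4] and [3, 4] (the remainder rows go to the earliest shards), and
# _get_concat_variable concatenates them back into a single [10, 4] tensor,
# caching the result in the CONCATENATED_VARIABLES collection so repeated
# lookups return the same concat op:
#
#   w = _get_concat_variable("W", [10, 4], dtypes.float32, num_shards=3)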
85
86def _norm(g, b, inp, scope):
87  shape = inp.get_shape()[-1:]
88  gamma_init = init_ops.constant_initializer(g)
89  beta_init = init_ops.constant_initializer(b)
90  with vs.variable_scope(scope):
91    # Initialize beta and gamma for use by layer_norm.
92    vs.get_variable("gamma", shape=shape, initializer=gamma_init)
93    vs.get_variable("beta", shape=shape, initializer=beta_init)
94  normalized = layers.layer_norm(inp, reuse=True, scope=scope)
95  return normalized
96
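# Note (added for exposition): _norm is invoked once per gate under a distinct
# scope (e.g. "transform", "forget", "output", "state" below). It first creates
# the "gamma" and "beta" variables with the requested initial values, then
# calls layers.layer_norm with reuse=True so the normalization uses exactly
# those variables.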
97
98class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
99  """Long short-term memory unit (LSTM) recurrent network cell.
100
101  The default non-peephole implementation is based on:
102
103    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
104
105  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
106  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
107
108  The peephole implementation is based on:
109
110    https://research.google.com/pubs/archive/43905.pdf
111
112  Hasim Sak, Andrew Senior, and Francoise Beaufays.
113  "Long short-term memory recurrent neural network architectures for
114   large scale acoustic modeling." INTERSPEECH, 2014.
115
116  The coupling of input and forget gate is based on:
117
118    http://arxiv.org/pdf/1503.04069.pdf
119
120  Greff et al. "LSTM: A Search Space Odyssey"
121
122  The class uses optional peephole connections and an optional projection
123  layer.
124  Layer normalization implementation is based on:
125
126    https://arxiv.org/abs/1607.06450.
127
128  "Layer Normalization"
129  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
130
131  and is applied before the internal nonlinearities.
132
133  """
134
135  def __init__(self,
136               num_units,
137               use_peepholes=False,
138               initializer=None,
139               num_proj=None,
140               proj_clip=None,
141               num_unit_shards=1,
142               num_proj_shards=1,
143               forget_bias=1.0,
144               state_is_tuple=True,
145               activation=math_ops.tanh,
146               reuse=None,
147               layer_norm=False,
148               norm_gain=1.0,
149               norm_shift=0.0):
150    """Initialize the parameters for an LSTM cell.
151
152    Args:
153      num_units: int, The number of units in the LSTM cell
154      use_peepholes: bool, set True to enable diagonal/peephole connections.
155      initializer: (optional) The initializer to use for the weight and
156        projection matrices.
157      num_proj: (optional) int, The output dimensionality for the projection
158        matrices.  If None, no projection is performed.
159      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
160        provided, then the projected values are clipped elementwise to within
161        `[-proj_clip, proj_clip]`.
162      num_unit_shards: How to split the weight matrix.  If >1, the weight
163        matrix is stored across num_unit_shards.
164      num_proj_shards: How to split the projection matrix.  If >1, the
165        projection matrix is stored across num_proj_shards.
166      forget_bias: Biases of the forget gate are initialized by default to 1
167        in order to reduce the scale of forgetting at the beginning of
168        the training.
169      state_is_tuple: If True, accepted and returned states are 2-tuples of
170        the `c_state` and `m_state`.  If False, they are concatenated
171        along the column axis.  The latter behavior will soon be deprecated.
172      activation: Activation function of the inner states.
173      reuse: (optional) Python boolean describing whether to reuse variables
174        in an existing scope.  If not `True`, and the existing scope already has
175        the given variables, an error is raised.
176      layer_norm: If `True`, layer normalization will be applied.
177      norm_gain: float, The layer normalization gain initial value. If
178        `layer_norm` has been set to `False`, this argument will be ignored.
179      norm_shift: float, The layer normalization shift initial value. If
180        `layer_norm` has been set to `False`, this argument will be ignored.
181    """
182    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
183    if not state_is_tuple:
184      logging.warn("%s: Using a concatenated state is slower and will soon be "
185                   "deprecated.  Use state_is_tuple=True.", self)
186    self._num_units = num_units
187    self._use_peepholes = use_peepholes
188    self._initializer = initializer
189    self._num_proj = num_proj
190    self._proj_clip = proj_clip
191    self._num_unit_shards = num_unit_shards
192    self._num_proj_shards = num_proj_shards
193    self._forget_bias = forget_bias
194    self._state_is_tuple = state_is_tuple
195    self._activation = activation
196    self._reuse = reuse
197    self._layer_norm = layer_norm
198    self._norm_gain = norm_gain
199    self._norm_shift = norm_shift
200
201    if num_proj:
202      self._state_size = (
203          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
204          if state_is_tuple else num_units + num_proj)
205      self._output_size = num_proj
206    else:
207      self._state_size = (
208          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
209          if state_is_tuple else 2 * num_units)
210      self._output_size = num_units
211
212  @property
213  def state_size(self):
214    return self._state_size
215
216  @property
217  def output_size(self):
218    return self._output_size
219
220  def call(self, inputs, state):
221    """Run one step of LSTM.
222
223    Args:
224      inputs: input Tensor, 2D, batch x input_size.
225      state: if `state_is_tuple` is False, this must be a state Tensor,
226        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
227        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
228        `m_state`.
229
230    Returns:
231      A tuple containing:
232      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
233        LSTM after reading `inputs` when previous state was `state`.
234        Here output_dim is:
235           num_proj if num_proj was set,
236           num_units otherwise.
237      - Tensor(s) representing the new state of LSTM after reading `inputs` when
238        the previous state was `state`.  Same type and shape(s) as `state`.
239
240    Raises:
241      ValueError: If input size cannot be inferred from inputs via
242        static shape inference.
243    """
244    sigmoid = math_ops.sigmoid
245
246    num_proj = self._num_units if self._num_proj is None else self._num_proj
247
248    if self._state_is_tuple:
249      (c_prev, m_prev) = state
250    else:
251      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
252      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
253
254    dtype = inputs.dtype
255    input_size = inputs.get_shape().with_rank(2).dims[1]
256    if input_size.value is None:
257      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
258    concat_w = _get_concat_variable(
259        "W",
260        [input_size.value + num_proj, 3 * self._num_units],
261        dtype,
262        self._num_unit_shards)
263
264    b = vs.get_variable(
265        "B",
266        shape=[3 * self._num_units],
267        initializer=init_ops.zeros_initializer(),
268        dtype=dtype)
269
270    # j = new_input, f = forget_gate, o = output_gate
271    cell_inputs = array_ops.concat([inputs, m_prev], 1)
272    lstm_matrix = math_ops.matmul(cell_inputs, concat_w)
273
274    # If layer normalization is applied, do not add bias
275    if not self._layer_norm:
276      lstm_matrix = nn_ops.bias_add(lstm_matrix, b)
277
278    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
279
280    # Apply layer normalization
281    if self._layer_norm:
282      j = _norm(self._norm_gain, self._norm_shift, j, "transform")
283      f = _norm(self._norm_gain, self._norm_shift, f, "forget")
284      o = _norm(self._norm_gain, self._norm_shift, o, "output")
285
286    # Diagonal connections
287    if self._use_peepholes:
288      w_f_diag = vs.get_variable(
289          "W_F_diag", shape=[self._num_units], dtype=dtype)
290      w_o_diag = vs.get_variable(
291          "W_O_diag", shape=[self._num_units], dtype=dtype)
292
293    if self._use_peepholes:
294      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
295    else:
296      f_act = sigmoid(f + self._forget_bias)
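    # Coupled input and forget gates (Greff et al.): the input gate is taken to
    # be 1 - f_act, so the new cell state is a convex combination of the
    # previous cell state and the transformed input.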
297    c = (f_act * c_prev + (1 - f_act) * self._activation(j))
298
299    # Apply layer normalization
300    if self._layer_norm:
301      c = _norm(self._norm_gain, self._norm_shift, c, "state")
302
303    if self._use_peepholes:
304      m = sigmoid(o + w_o_diag * c) * self._activation(c)
305    else:
306      m = sigmoid(o) * self._activation(c)
307
308    if self._num_proj is not None:
309      concat_w_proj = _get_concat_variable("W_P",
310                                           [self._num_units, self._num_proj],
311                                           dtype, self._num_proj_shards)
312
313      m = math_ops.matmul(m, concat_w_proj)
314      if self._proj_clip is not None:
315        # pylint: disable=invalid-unary-operand-type
316        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
317        # pylint: enable=invalid-unary-operand-type
318
319    new_state = (
320        rnn_cell_impl.LSTMStateTuple(c, m)
321        if self._state_is_tuple else array_ops.concat([c, m], 1))
322    return m, new_state
323
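# Example (an illustrative sketch, not executed as part of this module): one
# step of a CoupledInputForgetGateLSTMCell on dummy data. The batch size (8)
# and input width (32) are arbitrary.
#
#   cell = CoupledInputForgetGateLSTMCell(num_units=64, use_peepholes=True,
#                                         layer_norm=True)
#   inputs = array_ops.zeros([8, 32], dtype=dtypes.float32)
#   state = cell.zero_state(8, dtypes.float32)  # LSTMStateTuple(c, m)
#   output, new_state = cell(inputs, state)     # output shape: [8, 64]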
324
325class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
326  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
327
328  This implementation is based on:
329
330    Tara N. Sainath and Bo Li
331    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
332    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
333
334  It uses optional peephole connections and optional cell clipping.
335  """
336
337  def __init__(self,
338               num_units,
339               use_peepholes=False,
340               cell_clip=None,
341               initializer=None,
342               num_unit_shards=1,
343               forget_bias=1.0,
344               feature_size=None,
345               frequency_skip=1,
346               reuse=None):
347    """Initialize the parameters for an LSTM cell.
348
349    Args:
350      num_units: int, The number of units in the LSTM cell
351      use_peepholes: bool, set True to enable diagonal/peephole connections.
352      cell_clip: (optional) A float value, if provided the cell state is clipped
353        by this value prior to the cell output activation.
354      initializer: (optional) The initializer to use for the weight and
355        projection matrices.
356      num_unit_shards: int, How to split the weight matrix.  If >1, the weight
357        matrix is stored across num_unit_shards.
358      forget_bias: float, Biases of the forget gate are initialized by default
359        to 1 in order to reduce the scale of forgetting at the beginning
360        of the training.
361      feature_size: int, The size of the input feature the LSTM spans over.
362      frequency_skip: int, The amount the LSTM filter is shifted by in
363        frequency.
364      reuse: (optional) Python boolean describing whether to reuse variables
365        in an existing scope.  If not `True`, and the existing scope already has
366        the given variables, an error is raised.
367    """
368    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
369    self._num_units = num_units
370    self._use_peepholes = use_peepholes
371    self._cell_clip = cell_clip
372    self._initializer = initializer
373    self._num_unit_shards = num_unit_shards
374    self._forget_bias = forget_bias
375    self._feature_size = feature_size
376    self._frequency_skip = frequency_skip
377    self._state_size = 2 * num_units
378    self._output_size = num_units
379    self._reuse = reuse
380
381  @property
382  def output_size(self):
383    return self._output_size
384
385  @property
386  def state_size(self):
387    return self._state_size
388
389  def call(self, inputs, state):
390    """Run one step of LSTM.
391
392    Args:
393      inputs: input Tensor, 2D, batch x input_size.
394      state: state Tensor, 2D, batch x state_size.
395
396    Returns:
397      A tuple containing:
398      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
399        after reading "inputs" when previous state was "state".
400        Here output_dim is num_units.
401      - A 2D, batch x state_size, Tensor representing the new state of LSTM
402        after reading "inputs" when previous state was "state".
403    Raises:
404      ValueError: if an input_size was specified and the provided inputs have
405        a different dimension.
406    """
407    sigmoid = math_ops.sigmoid
408    tanh = math_ops.tanh
409
410    freq_inputs = self._make_tf_features(inputs)
411    dtype = inputs.dtype
412    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
413
414    concat_w = _get_concat_variable(
415        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
416        dtype, self._num_unit_shards)
417
418    b = vs.get_variable(
419        "B",
420        shape=[4 * self._num_units],
421        initializer=init_ops.zeros_initializer(),
422        dtype=dtype)
423
424    # Diagonal connections
425    if self._use_peepholes:
426      w_f_diag = vs.get_variable(
427          "W_F_diag", shape=[self._num_units], dtype=dtype)
428      w_i_diag = vs.get_variable(
429          "W_I_diag", shape=[self._num_units], dtype=dtype)
430      w_o_diag = vs.get_variable(
431          "W_O_diag", shape=[self._num_units], dtype=dtype)
432
433    # initialize the first freq state to be zero
434    m_prev_freq = array_ops.zeros(
435        [inputs.shape.dims[0].value or array_ops.shape(inputs)[0],
436         self._num_units], dtype)
437    for fq in range(len(freq_inputs)):
438      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
439                               [-1, self._num_units])
440      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
441                               [-1, self._num_units])
442      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
443      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1)
444      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
445      i, j, f, o = array_ops.split(
446          value=lstm_matrix, num_or_size_splits=4, axis=1)
447
448      if self._use_peepholes:
449        c = (
450            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
451            sigmoid(i + w_i_diag * c_prev) * tanh(j))
452      else:
453        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
454
455      if self._cell_clip is not None:
456        # pylint: disable=invalid-unary-operand-type
457        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
458        # pylint: enable=invalid-unary-operand-type
459
460      if self._use_peepholes:
461        m = sigmoid(o + w_o_diag * c) * tanh(c)
462      else:
463        m = sigmoid(o) * tanh(c)
464      m_prev_freq = m
465      if fq == 0:
466        state_out = array_ops.concat([c, m], 1)
467        m_out = m
468      else:
469        state_out = array_ops.concat([state_out, c, m], 1)
470        m_out = array_ops.concat([m_out, m], 1)
471    return m_out, state_out
472
473  def _make_tf_features(self, input_feat):
474    """Make the frequency features.
475
476    Args:
477      input_feat: input Tensor, 2D, batch x input_size.
478
479    Returns:
480      A list of frequency features, with each element containing:
481      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
482        for that frequency index. Here output_dim is feature_size.
483    Raises:
484      ValueError: if input_size cannot be inferred from static shape inference.
485    """
486    input_size = input_feat.get_shape().with_rank(2).dims[-1].value
487    if input_size is None:
488      raise ValueError("Cannot infer input_size from static shape inference.")
489    num_feats = int(
490        (input_size - self._feature_size) / (self._frequency_skip)) + 1
491    freq_inputs = []
492    for f in range(num_feats):
493      cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip],
494                                  [-1, self._feature_size])
495      freq_inputs.append(cur_input)
496    return freq_inputs
497
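# Example (an illustrative sketch, not executed as part of this module): the
# TimeFreqLSTMCell slides a filter of width feature_size over the frequency
# axis in steps of frequency_skip. With a 40-wide input, feature_size=8 and
# frequency_skip=4 it produces (40 - 8) / 4 + 1 = 9 frequency steps, all
# processed with shared weights, and the per-step outputs are concatenated.
#
#   cell = TimeFreqLSTMCell(num_units=16, use_peepholes=True,
#                           feature_size=8, frequency_skip=4)
#   inputs = array_ops.zeros([8, 40], dtype=dtypes.float32)  # batch of 8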
498
499class GridLSTMCell(rnn_cell_impl.RNNCell):
500  """Grid Long short-term memory unit (LSTM) recurrent network cell.
501
502  The default is based on:
503    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
504    "Grid Long Short-Term Memory," Proc. ICLR 2016.
505    http://arxiv.org/abs/1507.01526
506
507  When peephole connections are used, the implementation is based on:
508    Tara N. Sainath and Bo Li
509    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
510    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
511
512  The code uses optional peephole connections, shared_weights and cell clipping.
513  """
514
515  def __init__(self,
516               num_units,
517               use_peepholes=False,
518               share_time_frequency_weights=False,
519               cell_clip=None,
520               initializer=None,
521               num_unit_shards=1,
522               forget_bias=1.0,
523               feature_size=None,
524               frequency_skip=None,
525               num_frequency_blocks=None,
526               start_freqindex_list=None,
527               end_freqindex_list=None,
528               couple_input_forget_gates=False,
529               state_is_tuple=True,
530               reuse=None):
531    """Initialize the parameters for an LSTM cell.
532
533    Args:
534      num_units: int, The number of units in the LSTM cell
535      use_peepholes: (optional) bool, default False. Set True to enable
536        diagonal/peephole connections.
537      share_time_frequency_weights: (optional) bool, default False. Set True to
538        enable shared cell weights between time and frequency LSTMs.
539      cell_clip: (optional) A float value, default None, if provided the cell
540        state is clipped by this value prior to the cell output activation.
541      initializer: (optional) The initializer to use for the weight and
542        projection matrices, default None.
543      num_unit_shards: (optional) int, default 1, How to split the weight
544        matrix. If > 1, the weight matrix is stored across num_unit_shards.
545      forget_bias: (optional) float, default 1.0, The initial bias of the
546        forget gates, used to reduce the scale of forgetting at the beginning
547        of the training.
548      feature_size: (optional) int, default None, The size of the input feature
549        the LSTM spans over.
550      frequency_skip: (optional) int, default None, The amount the LSTM filter
551        is shifted by in frequency.
552      num_frequency_blocks: [required] A list with the number of frequency
553        blocks needed to cover the whole input feature, for each split defined
554        by start_freqindex_list and end_freqindex_list.
555      start_freqindex_list: [optional], list of ints, default None,  The
556        starting frequency index for each frequency block.
557      end_freqindex_list: [optional], list of ints, default None. The ending
558        frequency index for each frequency block.
559      couple_input_forget_gates: (optional) bool, default False, Whether to
560        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
561        model parameters and computation cost.
562      state_is_tuple: If True, accepted and returned states are 2-tuples of
563        the `c_state` and `m_state`.  If False, they are concatenated
564        along the column axis.  The latter behavior will soon be deprecated.
565      reuse: (optional) Python boolean describing whether to reuse variables
566        in an existing scope.  If not `True`, and the existing scope already has
567        the given variables, an error is raised.
568    Raises:
569      ValueError: if the num_frequency_blocks list is not specified
570    """
571    super(GridLSTMCell, self).__init__(_reuse=reuse)
572    if not state_is_tuple:
573      logging.warn("%s: Using a concatenated state is slower and will soon be "
574                   "deprecated.  Use state_is_tuple=True.", self)
575    self._num_units = num_units
576    self._use_peepholes = use_peepholes
577    self._share_time_frequency_weights = share_time_frequency_weights
578    self._couple_input_forget_gates = couple_input_forget_gates
579    self._state_is_tuple = state_is_tuple
580    self._cell_clip = cell_clip
581    self._initializer = initializer
582    self._num_unit_shards = num_unit_shards
583    self._forget_bias = forget_bias
584    self._feature_size = feature_size
585    self._frequency_skip = frequency_skip
586    self._start_freqindex_list = start_freqindex_list
587    self._end_freqindex_list = end_freqindex_list
588    self._num_frequency_blocks = num_frequency_blocks
589    self._total_blocks = 0
590    self._reuse = reuse
591    if self._num_frequency_blocks is None:
592      raise ValueError("Must specify num_frequency_blocks")
593
594    for block_index in range(len(self._num_frequency_blocks)):
595      self._total_blocks += int(self._num_frequency_blocks[block_index])
596    if state_is_tuple:
597      state_names = ""
598      for block_index in range(len(self._num_frequency_blocks)):
599        for freq_index in range(self._num_frequency_blocks[block_index]):
600          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
601          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
602      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
603                                                      state_names.strip(","))
604      self._state_size = self._state_tuple_type(*(
605          [num_units, num_units] * self._total_blocks))
606    else:
607      self._state_tuple_type = None
608      self._state_size = num_units * self._total_blocks * 2
609    self._output_size = num_units * self._total_blocks * 2
610
611  @property
612  def output_size(self):
613    return self._output_size
614
615  @property
616  def state_size(self):
617    return self._state_size
618
619  @property
620  def state_tuple_type(self):
621    return self._state_tuple_type
622
623  def call(self, inputs, state):
624    """Run one step of LSTM.
625
626    Args:
627      inputs: input Tensor, 2D, [batch, feature_size].
628      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depending on
629        the flag self._state_is_tuple.
630
631    Returns:
632      A tuple containing:
633      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
634        after reading "inputs" when previous state was "state".
635        Here output_dim is num_units.
636      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
637        after reading "inputs" when previous state was "state".
638    Raises:
639      ValueError: if an input_size was specified and the provided inputs have
640        a different dimension.
641    """
642    batch_size = tensor_shape.dimension_value(
643        inputs.shape[0]) or array_ops.shape(inputs)[0]
644    freq_inputs = self._make_tf_features(inputs)
645    m_out_lst = []
646    state_out_lst = []
647    for block in range(len(freq_inputs)):
648      m_out_lst_current, state_out_lst_current = self._compute(
649          freq_inputs[block],
650          block,
651          state,
652          batch_size,
653          state_is_tuple=self._state_is_tuple)
654      m_out_lst.extend(m_out_lst_current)
655      state_out_lst.extend(state_out_lst_current)
656    if self._state_is_tuple:
657      state_out = self._state_tuple_type(*state_out_lst)
658    else:
659      state_out = array_ops.concat(state_out_lst, 1)
660    m_out = array_ops.concat(m_out_lst, 1)
661    return m_out, state_out
662
663  def _compute(self,
664               freq_inputs,
665               block,
666               state,
667               batch_size,
668               state_prefix="state",
669               state_is_tuple=True):
670    """Run the actual computation of one step LSTM.
671
672    Args:
673      freq_inputs: list of Tensors, 2D, [batch, feature_size].
674      block: int, current frequency block index to process.
675      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depending on
676        the flag state_is_tuple.
677      batch_size: int32, batch size.
678      state_prefix: (optional) string, name prefix for states, defaults to
679        "state".
680      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.
681
682    Returns:
683      A tuple, containing:
684      - A list of [batch, output_dim] Tensors, representing the output of the
685        LSTM given the inputs and state.
686      - A list of [batch, state_size] Tensors, representing the LSTM state
687        values given the inputs and previous state.
688    """
689    sigmoid = math_ops.sigmoid
690    tanh = math_ops.tanh
691    num_gates = 3 if self._couple_input_forget_gates else 4
692    dtype = freq_inputs[0].dtype
693    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
694
695    concat_w_f = _get_concat_variable(
696        "W_f_%d" % block,
697        [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
698        dtype, self._num_unit_shards)
699    b_f = vs.get_variable(
700        "B_f_%d" % block,
701        shape=[num_gates * self._num_units],
702        initializer=init_ops.zeros_initializer(),
703        dtype=dtype)
704    if not self._share_time_frequency_weights:
705      concat_w_t = _get_concat_variable("W_t_%d" % block, [
706          actual_input_size + 2 * self._num_units, num_gates * self._num_units
707      ], dtype, self._num_unit_shards)
708      b_t = vs.get_variable(
709          "B_t_%d" % block,
710          shape=[num_gates * self._num_units],
711          initializer=init_ops.zeros_initializer(),
712          dtype=dtype)
713
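    # Peephole weight naming: the "freqf"/"timef" weights multiply the
    # frequency-direction cell state (c_prev_freq / c_freq), while the
    # "freqt"/"timet" weights multiply the time-direction cell state
    # (c_prev_time / c_time).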
714    if self._use_peepholes:
715      # Diagonal connections
716      if not self._couple_input_forget_gates:
717        w_f_diag_freqf = vs.get_variable(
718            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
719        w_f_diag_freqt = vs.get_variable(
720            "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
721      w_i_diag_freqf = vs.get_variable(
722          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
723      w_i_diag_freqt = vs.get_variable(
724          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
725      w_o_diag_freqf = vs.get_variable(
726          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
727      w_o_diag_freqt = vs.get_variable(
728          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
729      if not self._share_time_frequency_weights:
730        if not self._couple_input_forget_gates:
731          w_f_diag_timef = vs.get_variable(
732              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
733          w_f_diag_timet = vs.get_variable(
734              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
735        w_i_diag_timef = vs.get_variable(
736            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
737        w_i_diag_timet = vs.get_variable(
738            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
739        w_o_diag_timef = vs.get_variable(
740            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
741        w_o_diag_timet = vs.get_variable(
742            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
743
744    # initialize the first freq state to be zero
745    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
746    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
747    for freq_index in range(len(freq_inputs)):
748      if state_is_tuple:
749        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
750        c_prev_time = getattr(state, name_prefix + "_c")
751        m_prev_time = getattr(state, name_prefix + "_m")
752      else:
753        c_prev_time = array_ops.slice(
754            state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
755        m_prev_time = array_ops.slice(
756            state, [0, (2 * freq_index + 1) * self._num_units],
757            [-1, self._num_units])
758
759      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
760      cell_inputs = array_ops.concat(
761          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)
762
763      # F-LSTM
764      lstm_matrix_freq = nn_ops.bias_add(
765          math_ops.matmul(cell_inputs, concat_w_f), b_f)
766      if self._couple_input_forget_gates:
767        i_freq, j_freq, o_freq = array_ops.split(
768            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
769        f_freq = None
770      else:
771        i_freq, j_freq, f_freq, o_freq = array_ops.split(
772            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
773      # T-LSTM
774      if self._share_time_frequency_weights:
775        i_time = i_freq
776        j_time = j_freq
777        f_time = f_freq
778        o_time = o_freq
779      else:
780        lstm_matrix_time = nn_ops.bias_add(
781            math_ops.matmul(cell_inputs, concat_w_t), b_t)
782        if self._couple_input_forget_gates:
783          i_time, j_time, o_time = array_ops.split(
784              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
785          f_time = None
786        else:
787          i_time, j_time, f_time, o_time = array_ops.split(
788              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
789
790      # F-LSTM c_freq
791      # input gate activations
792      if self._use_peepholes:
793        i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
794                           w_i_diag_freqt * c_prev_time)
795      else:
796        i_freq_g = sigmoid(i_freq)
797      # forget gate activations
798      if self._couple_input_forget_gates:
799        f_freq_g = 1.0 - i_freq_g
800      else:
801        if self._use_peepholes:
802          f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
803                             c_prev_freq + w_f_diag_freqt * c_prev_time)
804        else:
805          f_freq_g = sigmoid(f_freq + self._forget_bias)
806      # cell state
807      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
808      if self._cell_clip is not None:
809        # pylint: disable=invalid-unary-operand-type
810        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
811                                        self._cell_clip)
812        # pylint: enable=invalid-unary-operand-type
813
814      # T-LSTM c_freq
815      # input gate activations
816      if self._use_peepholes:
817        if self._share_time_frequency_weights:
818          i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq +
819                             w_i_diag_freqt * c_prev_time)
820        else:
821          i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
822                             w_i_diag_timet * c_prev_time)
823      else:
824        i_time_g = sigmoid(i_time)
825      # forget gate activations
826      if self._couple_input_forget_gates:
827        f_time_g = 1.0 - i_time_g
828      else:
829        if self._use_peepholes:
830          if self._share_time_frequency_weights:
831            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf *
832                               c_prev_freq + w_f_diag_freqt * c_prev_time)
833          else:
834            f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
835                               c_prev_freq + w_f_diag_timet * c_prev_time)
836        else:
837          f_time_g = sigmoid(f_time + self._forget_bias)
838      # cell state
839      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
840      if self._cell_clip is not None:
841        # pylint: disable=invalid-unary-operand-type
842        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
843                                        self._cell_clip)
844        # pylint: enable=invalid-unary-operand-type
845
846      # F-LSTM m_freq
847      if self._use_peepholes:
848        m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
849                         w_o_diag_freqt * c_time) * tanh(c_freq)
850      else:
851        m_freq = sigmoid(o_freq) * tanh(c_freq)
852
853      # T-LSTM m_time
854      if self._use_peepholes:
855        if self._share_time_frequency_weights:
856          m_time = sigmoid(o_time + w_o_diag_freqf * c_freq +
857                           w_o_diag_freqt * c_time) * tanh(c_time)
858        else:
859          m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
860                           w_o_diag_timet * c_time) * tanh(c_time)
861      else:
862        m_time = sigmoid(o_time) * tanh(c_time)
863
864      m_prev_freq = m_freq
865      c_prev_freq = c_freq
866      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
867      if freq_index == 0:
868        state_out_lst = [c_time, m_time]
869        m_out_lst = [m_time, m_freq]
870      else:
871        state_out_lst.extend([c_time, m_time])
872        m_out_lst.extend([m_time, m_freq])
873
874    return m_out_lst, state_out_lst
875
876  def _make_tf_features(self, input_feat, slice_offset=0):
877    """Make the frequency features.
878
879    Args:
880      input_feat: input Tensor, 2D, [batch, input_size].
881      slice_offset: (optional) Python int, default 0, the slicing offset, only
882        used for the backward processing in the BidirectionalGridLSTMCell. It
883        specifies a different starting point instead of always 0 so that the
884        forward and backward processing look at different frequency blocks.
885
886    Returns:
887      A list of frequency features, with each element containing:
888      - A 2D, [batch, output_dim], Tensor representing the time-frequency
889        feature for that frequency index. Here output_dim is feature_size.
890    Raises:
891      ValueError: if input_size cannot be inferred from static shape inference.
892    """
893    input_size = input_feat.get_shape().with_rank(2).dims[-1].value
894    if input_size is None:
895      raise ValueError("Cannot infer input_size from static shape inference.")
896    if slice_offset > 0:
897      # Padding to the end
898      inputs = array_ops.pad(input_feat,
899                             array_ops.constant(
900                                 [0, 0, 0, slice_offset],
901                                 shape=[2, 2],
902                                 dtype=dtypes.int32), "CONSTANT")
903    elif slice_offset < 0:
904      # Padding to the front
905      inputs = array_ops.pad(input_feat,
906                             array_ops.constant(
907                                 [0, 0, -slice_offset, 0],
908                                 shape=[2, 2],
909                                 dtype=dtypes.int32), "CONSTANT")
910      slice_offset = 0
911    else:
912      inputs = input_feat
913    freq_inputs = []
914    if not self._start_freqindex_list:
915      if len(self._num_frequency_blocks) != 1:
916        raise ValueError("Length of num_frequency_blocks"
917                         " is not 1, but instead is %d" %
918                         len(self._num_frequency_blocks))
919      num_feats = int(
920          (input_size - self._feature_size) / (self._frequency_skip)) + 1
921      if num_feats != self._num_frequency_blocks[0]:
922        raise ValueError(
923            "Invalid num_frequency_blocks, requires %d but gets %d, please"
924            " check the input size and filter config are correct." %
925            (self._num_frequency_blocks[0], num_feats))
926      block_inputs = []
927      for f in range(num_feats):
928        cur_input = array_ops.slice(
929            inputs, [0, slice_offset + f * self._frequency_skip],
930            [-1, self._feature_size])
931        block_inputs.append(cur_input)
932      freq_inputs.append(block_inputs)
933    else:
934      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
935        raise ValueError("Length of start and end freqindex_list"
936                         " does not match: %d vs. %d" %
937                         (len(self._start_freqindex_list),
938                          len(self._end_freqindex_list)))
939      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
940        raise ValueError("Length of num_frequency_blocks"
941                         " is not equal to start_freqindex_list: %d vs. %d" %
942                         (len(self._num_frequency_blocks),
943                          len(self._start_freqindex_list)))
944      for b in range(len(self._start_freqindex_list)):
945        start_index = self._start_freqindex_list[b]
946        end_index = self._end_freqindex_list[b]
947        cur_size = end_index - start_index
948        block_feats = int(
949            (cur_size - self._feature_size) / (self._frequency_skip)) + 1
950        if block_feats != self._num_frequency_blocks[b]:
951          raise ValueError(
952              "Invalid num_frequency_blocks, requires %d but gets %d, please"
953              " check the input size and filter config are correct." %
954              (self._num_frequency_blocks[b], block_feats))
955        block_inputs = []
956        for f in range(block_feats):
957          cur_input = array_ops.slice(
958              inputs,
959              [0, start_index + slice_offset + f * self._frequency_skip],
960              [-1, self._feature_size])
961          block_inputs.append(cur_input)
962        freq_inputs.append(block_inputs)
963    return freq_inputs
964
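# Example (an illustrative sketch, not executed as part of this module):
# a GridLSTMCell for a 40-wide input with an 8-wide filter shifted by 4, which
# gives (40 - 8) / 4 + 1 = 9 frequency steps in a single block, so
# num_frequency_blocks=[9]. Each step emits a time output and a frequency
# output, hence output_size = num_units * 9 * 2.
#
#   cell = GridLSTMCell(num_units=16, use_peepholes=True,
#                       share_time_frequency_weights=True,
#                       feature_size=8, frequency_skip=4,
#                       num_frequency_blocks=[9])
#   inputs = array_ops.zeros([8, 40], dtype=dtypes.float32)
#   state = cell.zero_state(8, dtypes.float32)  # a GridLSTMStateTuple
#   output, new_state = cell(inputs, state)     # output shape: [8, 16 * 9 * 2]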
965
966class BidirectionalGridLSTMCell(GridLSTMCell):
967  """Bidirectional GridLstm cell.
968
969  The bidirection connection is only used in the frequency direction, which
970  hence doesn't affect the time direction's real-time processing that is
971  required for online recognition systems.
972  The current implementation uses different weights for the two directions.
973  """
974
975  def __init__(self,
976               num_units,
977               use_peepholes=False,
978               share_time_frequency_weights=False,
979               cell_clip=None,
980               initializer=None,
981               num_unit_shards=1,
982               forget_bias=1.0,
983               feature_size=None,
984               frequency_skip=None,
985               num_frequency_blocks=None,
986               start_freqindex_list=None,
987               end_freqindex_list=None,
988               couple_input_forget_gates=False,
989               backward_slice_offset=0,
990               reuse=None):
991    """Initialize the parameters for an LSTM cell.
992
993    Args:
994      num_units: int, The number of units in the LSTM cell
995      use_peepholes: (optional) bool, default False. Set True to enable
996        diagonal/peephole connections.
997      share_time_frequency_weights: (optional) bool, default False. Set True to
998        enable shared cell weights between time and frequency LSTMs.
999      cell_clip: (optional) A float value, default None, if provided the cell
1000        state is clipped by this value prior to the cell output activation.
1001      initializer: (optional) The initializer to use for the weight and
1002        projection matrices, default None.
1003      num_unit_shards: (optional) int, default 1, How to split the weight
1004        matrix. If > 1, the weight matrix is stored across num_unit_shards.
1005      forget_bias: (optional) float, default 1.0, The initial bias of the
1006        forget gates, used to reduce the scale of forgetting at the beginning
1007        of the training.
1008      feature_size: (optional) int, default None, The size of the input feature
1009        the LSTM spans over.
1010      frequency_skip: (optional) int, default None, The amount the LSTM filter
1011        is shifted by in frequency.
1012      num_frequency_blocks: [required] A list with the number of frequency
1013        blocks needed to cover the whole input feature, for each split defined
1014        by start_freqindex_list and end_freqindex_list.
1015      start_freqindex_list: [optional], list of ints, default None,  The
1016        starting frequency index for each frequency block.
1017      end_freqindex_list: [optional], list of ints, default None. The ending
1018        frequency index for each frequency block.
1019      couple_input_forget_gates: (optional) bool, default False, Whether to
1020        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
1021        model parameters and computation cost.
1022      backward_slice_offset: (optional) int32, default 0, the starting offset to
1023        slice the feature for backward processing.
1024      reuse: (optional) Python boolean describing whether to reuse variables
1025        in an existing scope.  If not `True`, and the existing scope already has
1026        the given variables, an error is raised.
1027    """
1028    super(BidirectionalGridLSTMCell, self).__init__(
1029        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
1030        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
1031        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
1032        couple_input_forget_gates, True, reuse)
1033    self._backward_slice_offset = int(backward_slice_offset)
1034    state_names = ""
1035    for direction in ["fwd", "bwd"]:
1036      for block_index in range(len(self._num_frequency_blocks)):
1037        for freq_index in range(self._num_frequency_blocks[block_index]):
1038          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
1039                                                  block_index)
1040          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
1041    self._state_tuple_type = collections.namedtuple(
1042        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
1043    self._state_size = self._state_tuple_type(*(
1044        [num_units, num_units] * self._total_blocks * 2))
1045    self._output_size = 2 * num_units * self._total_blocks * 2
1046
1047  def call(self, inputs, state):
1048    """Run one step of LSTM.
1049
1050    Args:
1051      inputs: input Tensor, 2D, [batch, input_size].
1052      state: tuple of Tensors, 2D, [batch, state_size].
1053
1054    Returns:
1055      A tuple containing:
1056      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
1057        after reading "inputs" when previous state was "state".
1058        Here output_dim is num_units.
1059      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
1060        after reading "inputs" when previous state was "state".
1061    Raises:
1062      ValueError: if an input_size was specified and the provided inputs have
1063        a different dimension.
1064    """
1065    batch_size = tensor_shape.dimension_value(
1066        inputs.shape[0]) or array_ops.shape(inputs)[0]
1067    fwd_inputs = self._make_tf_features(inputs)
1068    if self._backward_slice_offset:
1069      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
1070    else:
1071      bwd_inputs = fwd_inputs
1072
1073    # Forward processing
1074    with vs.variable_scope("fwd"):
1075      fwd_m_out_lst = []
1076      fwd_state_out_lst = []
1077      for block in range(len(fwd_inputs)):
1078        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
1079            fwd_inputs[block],
1080            block,
1081            state,
1082            batch_size,
1083            state_prefix="fwd_state",
1084            state_is_tuple=True)
1085        fwd_m_out_lst.extend(fwd_m_out_lst_current)
1086        fwd_state_out_lst.extend(fwd_state_out_lst_current)
1087    # Backward processing
1088    bwd_m_out_lst = []
1089    bwd_state_out_lst = []
1090    with vs.variable_scope("bwd"):
1091      for block in range(len(bwd_inputs)):
1092        # Reverse the blocks
1093        bwd_inputs_reverse = bwd_inputs[block][::-1]
1094        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
1095            bwd_inputs_reverse,
1096            block,
1097            state,
1098            batch_size,
1099            state_prefix="bwd_state",
1100            state_is_tuple=True)
1101        bwd_m_out_lst.extend(bwd_m_out_lst_current)
1102        bwd_state_out_lst.extend(bwd_state_out_lst_current)
1103    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
1104    # Outputs are always concatenated, as they are never used separately.
1105    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
1106    return m_out, state_out
1107
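# Note (added for exposition): for the same configuration as the GridLSTMCell
# sketch above, the bidirectional variant processes the frequency blocks both
# forward and backward with separate "fwd"/"bwd" weight scopes, so its output
# is twice as wide (output_size = 2 * num_units * total_blocks * 2) and its
# state tuple carries both fwd_state_* and bwd_state_* entries.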
1108
1109# pylint: disable=protected-access
1110_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name
1111
1112# pylint: enable=protected-access
1113
1114
1115class AttentionCellWrapper(rnn_cell_impl.RNNCell):
1116  """Basic attention cell wrapper.
1117
1118  Implementation based on https://arxiv.org/abs/1601.06733.
1119  """
1120
1121  def __init__(self,
1122               cell,
1123               attn_length,
1124               attn_size=None,
1125               attn_vec_size=None,
1126               input_size=None,
1127               state_is_tuple=True,
1128               reuse=None):
1129    """Create a cell with attention.
1130
1131    Args:
1132      cell: an RNNCell, an attention is added to it.
1133      attn_length: integer, the size of an attention window.
1134      attn_size: integer, the size of an attention vector. Equal to
1135          cell.output_size by default.
1136      attn_vec_size: integer, the number of convolutional features calculated
1137          on the attention state and the size of the hidden layer built from
1138          the base cell state. Equal to attn_size by default.
1139      input_size: integer, the size of a hidden linear layer,
1140          built from inputs and attention. Derived from the input tensor
1141          by default.
1142      state_is_tuple: If True, accepted and returned states are 3-tuples of
1143        the wrapped cell state, the attention vector, and the attention state.
1144        If False, the states are all concatenated along the column axis.
1145      reuse: (optional) Python boolean describing whether to reuse variables
1146        in an existing scope.  If not `True`, and the existing scope already has
1147        the given variables, an error is raised.
1148
1149    Raises:
1150      TypeError: if cell is not an RNNCell.
1151      ValueError: if cell returns a state tuple but the flag
1152          `state_is_tuple` is `False` or if attn_length is zero or less.
1153    """
1154    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
1155    rnn_cell_impl.assert_like_rnncell("cell", cell)
1156    if nest.is_sequence(cell.state_size) and not state_is_tuple:
1157      raise ValueError(
1158          "Cell returns tuple of states, but the flag "
1159          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
1160    if attn_length <= 0:
1161      raise ValueError(
1162          "attn_length should be greater than zero, got %s" % str(attn_length))
1163    if not state_is_tuple:
1164      logging.warn("%s: Using a concatenated state is slower and will soon be "
1165                   "deprecated.  Use state_is_tuple=True.", self)
1166    if attn_size is None:
1167      attn_size = cell.output_size
1168    if attn_vec_size is None:
1169      attn_vec_size = attn_size
1170    self._state_is_tuple = state_is_tuple
1171    self._cell = cell
1172    self._attn_vec_size = attn_vec_size
1173    self._input_size = input_size
1174    self._attn_size = attn_size
1175    self._attn_length = attn_length
1176    self._reuse = reuse
1177    self._linear1 = None
1178    self._linear2 = None
1179    self._linear3 = None
1180
1181  @property
1182  def state_size(self):
1183    size = (self._cell.state_size, self._attn_size,
1184            self._attn_size * self._attn_length)
1185    if self._state_is_tuple:
1186      return size
1187    else:
1188      return sum(list(size))
1189
1190  @property
1191  def output_size(self):
1192    return self._attn_size
1193
1194  def call(self, inputs, state):
1195    """Long short-term memory cell with attention (LSTMA)."""
1196    if self._state_is_tuple:
1197      state, attns, attn_states = state
1198    else:
1199      states = state
1200      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
1201      attns = array_ops.slice(states, [0, self._cell.state_size],
1202                              [-1, self._attn_size])
1203      attn_states = array_ops.slice(
1204          states, [0, self._cell.state_size + self._attn_size],
1205          [-1, self._attn_size * self._attn_length])
1206    attn_states = array_ops.reshape(attn_states,
1207                                    [-1, self._attn_length, self._attn_size])
1208    input_size = self._input_size
1209    if input_size is None:
1210      input_size = inputs.get_shape().as_list()[1]
1211    if self._linear1 is None:
1212      self._linear1 = _Linear([inputs, attns], input_size, True)
1213    inputs = self._linear1([inputs, attns])
1214    cell_output, new_state = self._cell(inputs, state)
1215    if self._state_is_tuple:
1216      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
1217    else:
1218      new_state_cat = new_state
1219    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
1220    with vs.variable_scope("attn_output_projection"):
1221      if self._linear2 is None:
1222        self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True)
1223      output = self._linear2([cell_output, new_attns])
1224    new_attn_states = array_ops.concat(
1225        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
1226    new_attn_states = array_ops.reshape(
1227        new_attn_states, [-1, self._attn_length * self._attn_size])
1228    new_state = (new_state, new_attns, new_attn_states)
1229    if not self._state_is_tuple:
1230      new_state = array_ops.concat(list(new_state), 1)
1231    return output, new_state
1232
1233  def _attention(self, query, attn_states):
1234    conv2d = nn_ops.conv2d
1235    reduce_sum = math_ops.reduce_sum
1236    softmax = nn_ops.softmax
1237    tanh = math_ops.tanh
1238
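    # Content-based additive attention: each of the attn_length stored outputs
    # is scored against the current cell state, the scores are softmax
    # normalized, and the attended vector is their weighted sum.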
1239    with vs.variable_scope("attention"):
1240      k = vs.get_variable("attn_w",
1241                          [1, 1, self._attn_size, self._attn_vec_size])
1242      v = vs.get_variable("attn_v", [self._attn_vec_size])
1243      hidden = array_ops.reshape(attn_states,
1244                                 [-1, self._attn_length, 1, self._attn_size])
1245      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
1246      if self._linear3 is None:
1247        self._linear3 = _Linear(query, self._attn_vec_size, True)
1248      y = self._linear3(query)
1249      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
1250      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
1251      a = softmax(s)
1252      d = reduce_sum(
1253          array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
1254      new_attns = array_ops.reshape(d, [-1, self._attn_size])
1255      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
1256      return new_attns, new_attn_states
1257
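# Example (an illustrative sketch, not executed as part of this module):
# wrapping a base cell so that each step attends over its last attn_length
# outputs. The base cell choice and sizes are arbitrary.
#
#   base_cell = rnn_cell_impl.LSTMCell(64)
#   attn_cell = AttentionCellWrapper(base_cell, attn_length=10)
#   inputs = array_ops.zeros([8, 32], dtype=dtypes.float32)
#   state = attn_cell.zero_state(8, dtypes.float32)
#   output, new_state = attn_cell(inputs, state)  # output shape: [8, 64]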
1258
1259class HighwayWrapper(rnn_cell_impl.RNNCell):
1260  """RNNCell wrapper that adds highway connection on cell input and output.
1261
1262  Based on:
1263    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
1264    arXiv preprint arXiv:1505.00387, 2015.
1265    https://arxiv.org/abs/1505.00387
1266  """
1267
1268  def __init__(self,
1269               cell,
1270               couple_carry_transform_gates=True,
1271               carry_bias_init=1.0):
1272    """Constructs a `HighwayWrapper` for `cell`.
1273
1274    Args:
1275      cell: An instance of `RNNCell`.
1276      couple_carry_transform_gates: boolean, whether the carry and transform
1277        gates should be coupled.
1278      carry_bias_init: float, carry gates bias initialization.
1279    """
1280    self._cell = cell
1281    self._couple_carry_transform_gates = couple_carry_transform_gates
1282    self._carry_bias_init = carry_bias_init
1283
1284  @property
1285  def state_size(self):
1286    return self._cell.state_size
1287
1288  @property
1289  def output_size(self):
1290    return self._cell.output_size
1291
1292  def zero_state(self, batch_size, dtype):
1293    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
1294      return self._cell.zero_state(batch_size, dtype)
1295
1296  def _highway(self, inp, out):
1297    input_size = inp.get_shape().with_rank(2).dims[1].value
1298    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
1299    carry_bias = vs.get_variable(
1300        "carry_b", [input_size],
1301        initializer=init_ops.constant_initializer(self._carry_bias_init))
1302    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
1303    if self._couple_carry_transform_gates:
1304      transform = 1 - carry
1305    else:
1306      transform_weight = vs.get_variable("transform_w",
1307                                         [input_size, input_size])
1308      transform_bias = vs.get_variable(
1309          "transform_b", [input_size],
1310          initializer=init_ops.constant_initializer(-self._carry_bias_init))
1311      transform = math_ops.sigmoid(
1312          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
1313    return inp * carry + out * transform
1314
1315  def __call__(self, inputs, state, scope=None):
1316    """Run the cell and mix its inputs and outputs with a highway connection.
1317
1318    Args:
1319      inputs: cell inputs.
1320      state: cell state.
1321      scope: optional cell scope.
1322
1323    Returns:
1324      Tuple of cell outputs and new state.
1325
1326    Raises:
1327      TypeError: If cell inputs and outputs have different structure (type).
1328      ValueError: If cell inputs and outputs have different structure (value).
1329    """
1330    outputs, new_state = self._cell(inputs, state, scope=scope)
1331    nest.assert_same_structure(inputs, outputs)
1332
1333    # Ensure shapes match
1334    def assert_shape_match(inp, out):
1335      inp.get_shape().assert_is_compatible_with(out.get_shape())
1336
1337    nest.map_structure(assert_shape_match, inputs, outputs)
1338    res_outputs = nest.map_structure(self._highway, inputs, outputs)
1339    return (res_outputs, new_state)
1340
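
# Illustrative usage sketch (not part of the original module): a single step of
# a GRU cell wrapped in `HighwayWrapper`.  The helper name, the shapes, and the
# choice of `GRUCell` below are assumptions for demonstration only; graph-mode
# TF 1.x semantics are assumed.
def _highway_wrapper_example():
  """Builds one highway-wrapped GRU step; returns (output, new_state)."""
  num_units, batch_size = 16, 4
  # The highway connection mixes inputs and outputs elementwise, so the input
  # depth must equal the wrapped cell's output size.
  cell = HighwayWrapper(rnn_cell_impl.GRUCell(num_units))
  inputs = array_ops.zeros([batch_size, num_units])
  state = cell.zero_state(batch_size, dtypes.float32)
  output, new_state = cell(inputs, state)
  return output, new_state
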
1341
1342class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
1343  """LSTM unit with layer normalization and recurrent dropout.
1344
1345  This class adds layer normalization and recurrent dropout to a
1346  basic LSTM unit. Layer normalization implementation is based on:
1347
1348    https://arxiv.org/abs/1607.06450.
1349
1350  "Layer Normalization"
1351  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
1352
1353  and is applied before the internal nonlinearities.
1354  Recurrent dropout is based on:
1355
1356    https://arxiv.org/abs/1603.05118
1357
1358  "Recurrent Dropout without Memory Loss"
1359  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
1360  """
1361
1362  def __init__(self,
1363               num_units,
1364               forget_bias=1.0,
1365               input_size=None,
1366               activation=math_ops.tanh,
1367               layer_norm=True,
1368               norm_gain=1.0,
1369               norm_shift=0.0,
1370               dropout_keep_prob=1.0,
1371               dropout_prob_seed=None,
1372               reuse=None):
1373    """Initializes the basic LSTM cell.
1374
1375    Args:
1376      num_units: int, The number of units in the LSTM cell.
1377      forget_bias: float, The bias added to forget gates (see above).
1378      input_size: Deprecated and unused.
1379      activation: Activation function of the inner states.
1380      layer_norm: If `True`, layer normalization will be applied.
1381      norm_gain: float, The layer normalization gain initial value. If
1382        `layer_norm` has been set to `False`, this argument will be ignored.
1383      norm_shift: float, The layer normalization shift initial value. If
1384        `layer_norm` has been set to `False`, this argument will be ignored.
1385      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
1386        keep probability used for recurrent dropout. If it is a float equal to
1387        1.0, no dropout will be applied.
1388      dropout_prob_seed: (optional) integer, the randomness seed.
1389      reuse: (optional) Python boolean describing whether to reuse variables
1390        in an existing scope.  If not `True`, and the existing scope already has
1391        the given variables, an error is raised.
1392    """
1393    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)
1394
1395    if input_size is not None:
1396      logging.warn("%s: The input_size parameter is deprecated.", self)
1397
1398    self._num_units = num_units
1399    self._activation = activation
1400    self._forget_bias = forget_bias
1401    self._keep_prob = dropout_keep_prob
1402    self._seed = dropout_prob_seed
1403    self._layer_norm = layer_norm
1404    self._norm_gain = norm_gain
1405    self._norm_shift = norm_shift
1406    self._reuse = reuse
1407
1408  @property
1409  def state_size(self):
1410    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
1411
1412  @property
1413  def output_size(self):
1414    return self._num_units
1415
1416  def _norm(self, inp, scope, dtype=dtypes.float32):
1417    shape = inp.get_shape()[-1:]
1418    gamma_init = init_ops.constant_initializer(self._norm_gain)
1419    beta_init = init_ops.constant_initializer(self._norm_shift)
1420    with vs.variable_scope(scope):
1421      # Initialize beta and gamma for use by layer_norm.
1422      vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
1423      vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
1424    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
1425    return normalized
1426
1427  def _linear(self, args):
1428    out_size = 4 * self._num_units
1429    proj_size = args.get_shape()[-1]
1430    dtype = args.dtype
1431    weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
1432    out = math_ops.matmul(args, weights)
1433    if not self._layer_norm:
1434      bias = vs.get_variable("bias", [out_size], dtype=dtype)
1435      out = nn_ops.bias_add(out, bias)
1436    return out
1437
1438  def call(self, inputs, state):
1439    """LSTM cell with layer normalization and recurrent dropout."""
1440    c, h = state
1441    args = array_ops.concat([inputs, h], 1)
1442    concat = self._linear(args)
1443    dtype = args.dtype
1444
1445    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
1446    if self._layer_norm:
1447      i = self._norm(i, "input", dtype=dtype)
1448      j = self._norm(j, "transform", dtype=dtype)
1449      f = self._norm(f, "forget", dtype=dtype)
1450      o = self._norm(o, "output", dtype=dtype)
1451
1452    g = self._activation(j)
1453    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
1454      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
1455
1456    new_c = (
1457        c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g)
1458    if self._layer_norm:
1459      new_c = self._norm(new_c, "state", dtype=dtype)
1460    new_h = self._activation(new_c) * math_ops.sigmoid(o)
1461
1462    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
1463    return new_h, new_state
1464
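
# Illustrative usage sketch (not part of the original module): a single step of
# `LayerNormBasicLSTMCell` with layer normalization and recurrent dropout
# enabled.  All names and shapes below are arbitrary demonstration choices;
# graph-mode TF 1.x semantics are assumed.
def _layer_norm_basic_lstm_example():
  """Builds one layer-normalized LSTM step; returns (output, new_state)."""
  batch_size, input_depth, num_units = 4, 10, 8
  cell = LayerNormBasicLSTMCell(
      num_units, layer_norm=True, dropout_keep_prob=0.9)
  inputs = array_ops.zeros([batch_size, input_depth])
  state = cell.zero_state(batch_size, dtypes.float32)  # LSTMStateTuple(c, h)
  output, new_state = cell(inputs, state)  # output: [batch_size, num_units]
  return output, new_state
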
1465
1466class NASCell(rnn_cell_impl.LayerRNNCell):
1467  """Neural Architecture Search (NAS) recurrent network cell.
1468
1469  This implements the recurrent cell from the paper:
1470
1471    https://arxiv.org/abs/1611.01578
1472
1473  Barret Zoph and Quoc V. Le.
1474  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.
1475
1476  The class uses an optional projection layer.
1477  """
1478
1479  # NAS cell's architecture base.
1480  _NAS_BASE = 8
1481
1482  def __init__(self, num_units, num_proj=None, use_bias=False, reuse=None,
1483               **kwargs):
1484    """Initialize the parameters for a NAS cell.
1485
1486    Args:
1487      num_units: int, The number of units in the NAS cell.
1488      num_proj: (optional) int, The output dimensionality for the projection
1489        matrices.  If None, no projection is performed.
1490      use_bias: (optional) bool, If True then use biases within the cell. This
1491        is False by default.
1492      reuse: (optional) Python boolean describing whether to reuse variables
1493        in an existing scope.  If not `True`, and the existing scope already has
1494        the given variables, an error is raised.
1495      **kwargs: Additional keyword arguments.
1496    """
1497    super(NASCell, self).__init__(_reuse=reuse, **kwargs)
1498    self._num_units = num_units
1499    self._num_proj = num_proj
1500    self._use_bias = use_bias
1501    self._reuse = reuse
1502
1503    if num_proj is not None:
1504      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
1505      self._output_size = num_proj
1506    else:
1507      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
1508      self._output_size = num_units
1509
1510  @property
1511  def state_size(self):
1512    return self._state_size
1513
1514  @property
1515  def output_size(self):
1516    return self._output_size
1517
1518  def build(self, inputs_shape):
1519    input_size = tensor_shape.dimension_value(
1520        tensor_shape.TensorShape(inputs_shape).with_rank(2)[1])
1521    if input_size is None:
1522      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1523
1524    num_proj = self._num_units if self._num_proj is None else self._num_proj
1525
1526    # Variables for the NAS cell. `recurrent_kernel` holds the matrices that
1527    # multiply the hidden state and `kernel` those that multiply the inputs.
1528    self.recurrent_kernel = self.add_variable(
1529        "recurrent_kernel", [num_proj, self._NAS_BASE * self._num_units])
1530    self.kernel = self.add_variable(
1531        "kernel", [input_size, self._NAS_BASE * self._num_units])
1532
1533    if self._use_bias:
1534      self.bias = self.add_variable("bias",
1535                                    shape=[self._NAS_BASE * self._num_units],
1536                                    initializer=init_ops.zeros_initializer)
1537
1538    # Projection layer if specified
1539    if self._num_proj is not None:
1540      self.projection_weights = self.add_variable(
1541          "projection_weights", [self._num_units, self._num_proj])
1542
1543    self.built = True
1544
1545  def call(self, inputs, state):
1546    """Run one step of NAS Cell.
1547
1548    Args:
1549      inputs: input Tensor, 2D, batch x num_units.
1550      state: This must be a tuple of state Tensors, both `2-D`, with column
1551        sizes `c_state` and `m_state`.
1552
1553    Returns:
1554      A tuple containing:
1555      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
1556        NAS Cell after reading `inputs` when previous state was `state`.
1557        Here output_dim is:
1558           num_proj if num_proj was set,
1559           num_units otherwise.
1560      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
1561        when the previous state was `state`.  Same type and shape(s) as `state`.
1562
1563    Raises:
1564      ValueError: If input size cannot be inferred from inputs via
1565        static shape inference.
1566    """
1567    sigmoid = math_ops.sigmoid
1568    tanh = math_ops.tanh
1569    relu = nn_ops.relu
1570
1571    (c_prev, m_prev) = state
1572
1573    m_matrix = math_ops.matmul(m_prev, self.recurrent_kernel)
1574    inputs_matrix = math_ops.matmul(inputs, self.kernel)
1575
1576    if self._use_bias:
1577      m_matrix = nn_ops.bias_add(m_matrix, self.bias)
1578
1579    # The NAS cell branches into 8 different splits for both the hidden state
1580    # and the input.
1581    m_matrix_splits = array_ops.split(
1582        axis=1, num_or_size_splits=self._NAS_BASE, value=m_matrix)
1583    inputs_matrix_splits = array_ops.split(
1584        axis=1, num_or_size_splits=self._NAS_BASE, value=inputs_matrix)
1585
1586    # First layer
1587    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
1588    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
1589    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
1590    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
1591    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
1592    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
1593    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
1594    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
1595
1596    # Second layer
1597    l2_0 = tanh(layer1_0 * layer1_1)
1598    l2_1 = tanh(layer1_2 + layer1_3)
1599    l2_2 = tanh(layer1_4 * layer1_5)
1600    l2_3 = sigmoid(layer1_6 + layer1_7)
1601
1602    # Inject the cell
1603    l2_0 = tanh(l2_0 + c_prev)
1604
1605    # Third layer
1606    l3_0_pre = l2_0 * l2_1
1607    new_c = l3_0_pre  # create new cell
1608    l3_0 = l3_0_pre
1609    l3_1 = tanh(l2_2 + l2_3)
1610
1611    # Final layer
1612    new_m = tanh(l3_0 * l3_1)
1613
1614    # Projection layer if specified
1615    if self._num_proj is not None:
1616      new_m = math_ops.matmul(new_m, self.projection_weights)
1617
1618    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
1619    return new_m, new_state
1620
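
# Illustrative usage sketch (not part of the original module): one step of
# `NASCell` with an output projection, so the emitted output has `num_proj`
# columns while the cell state keeps `num_units`.  Names and shapes are
# demonstration assumptions; graph-mode TF 1.x semantics are assumed.
def _nas_cell_example():
  """Builds one NASCell step with projection; returns (output, new_state)."""
  batch_size, input_depth, num_units, num_proj = 4, 10, 8, 6
  cell = NASCell(num_units=num_units, num_proj=num_proj, use_bias=True)
  inputs = array_ops.zeros([batch_size, input_depth])
  state = cell.zero_state(batch_size, dtypes.float32)  # c: [4, 8], m: [4, 6]
  output, new_state = cell(inputs, state)  # output: [batch_size, num_proj]
  return output, new_state
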
1621
1622class UGRNNCell(rnn_cell_impl.RNNCell):
1623  """Update Gate Recurrent Neural Network (UGRNN) cell.
1624
1625  Compromise between an LSTM/GRU and a vanilla RNN.  There is only one
1626  gate, and that gate determines whether the unit should be
1627  integrating or computing instantaneously.  This is the recurrent
1628  analogue of the feedforward Highway Network.
1629
1630  This implements the recurrent cell from the paper:
1631
1632    https://arxiv.org/abs/1611.09913
1633
1634  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
1635  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
1636  """
1637
1638  def __init__(self,
1639               num_units,
1640               initializer=None,
1641               forget_bias=1.0,
1642               activation=math_ops.tanh,
1643               reuse=None):
1644    """Initialize the parameters for an UGRNN cell.
1645
1646    Args:
1647      num_units: int, The number of units in the UGRNN cell
1648      initializer: (optional) The initializer to use for the weight matrices.
1649      forget_bias: (optional) float, default 1.0, The initial bias of the
1650        forget gate, used to reduce the scale of forgetting at the beginning
1651        of the training.
1652      activation: (optional) Activation function of the inner states.
1653        Default is `tf.tanh`.
1654      reuse: (optional) Python boolean describing whether to reuse variables
1655        in an existing scope.  If not `True`, and the existing scope already has
1656        the given variables, an error is raised.
1657    """
1658    super(UGRNNCell, self).__init__(_reuse=reuse)
1659    self._num_units = num_units
1660    self._initializer = initializer
1661    self._forget_bias = forget_bias
1662    self._activation = activation
1663    self._reuse = reuse
1664    self._linear = None
1665
1666  @property
1667  def state_size(self):
1668    return self._num_units
1669
1670  @property
1671  def output_size(self):
1672    return self._num_units
1673
1674  def call(self, inputs, state):
1675    """Run one step of UGRNN.
1676
1677    Args:
1678      inputs: input Tensor, 2D, batch x input size.
1679      state: state Tensor, 2D, batch x num units.
1680
1681    Returns:
1682      new_output: batch x num units, Tensor representing the output of the UGRNN
1683        after reading `inputs` when previous state was `state`. Identical to
1684        `new_state`.
1685      new_state: batch x num units, Tensor representing the state of the UGRNN
1686        after reading `inputs` when previous state was `state`.
1687
1688    Raises:
1689      ValueError: If input size cannot be inferred from inputs via
1690        static shape inference.
1691    """
1692    sigmoid = math_ops.sigmoid
1693
1694    input_size = inputs.get_shape().with_rank(2).dims[1]
1695    if input_size.value is None:
1696      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1697
1698    with vs.variable_scope(
1699        vs.get_variable_scope(), initializer=self._initializer):
1700      cell_inputs = array_ops.concat([inputs, state], 1)
1701      if self._linear is None:
1702        self._linear = _Linear(cell_inputs, 2 * self._num_units, True)
1703      rnn_matrix = self._linear(cell_inputs)
1704
1705      [g_act, c_act] = array_ops.split(
1706          axis=1, num_or_size_splits=2, value=rnn_matrix)
1707
1708      c = self._activation(c_act)
1709      g = sigmoid(g_act + self._forget_bias)
1710      new_state = g * state + (1.0 - g) * c
1711      new_output = new_state
1712
1713    return new_output, new_state
1714
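
# Illustrative usage sketch (not part of the original module): one UGRNN step.
# The single gate decides, per unit, between integrating the previous state and
# using the freshly computed candidate.  Names and shapes are demonstration
# assumptions; graph-mode TF 1.x semantics are assumed.
def _ugrnn_cell_example():
  """Builds one UGRNN step; returns (output, new_state)."""
  batch_size, input_depth, num_units = 4, 10, 8
  cell = UGRNNCell(num_units)
  inputs = array_ops.zeros([batch_size, input_depth])
  state = cell.zero_state(batch_size, dtypes.float32)  # [batch_size, num_units]
  output, new_state = cell(inputs, state)  # output is identical to new_state
  return output, new_state
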
1715
1716class IntersectionRNNCell(rnn_cell_impl.RNNCell):
1717  """Intersection Recurrent Neural Network (+RNN) cell.
1718
1719  Architecture with coupled recurrent gate as well as coupled depth
1720  gate, designed to improve information flow through stacked RNNs. As the
1721  architecture uses depth gating, the dimensionality of the depth
1722  output (y) also should not change through depth (input size == output size).
1723  To achieve this, the first layer of a stacked Intersection RNN projects
1724  the inputs to N (num units) dimensions. Therefore when initializing an
1725  IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
1726  and use default settings for subsequent layers.
1727
1728  This implements the recurrent cell from the paper:
1729
1730    https://arxiv.org/abs/1611.09913
1731
1732  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
1733  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
1734
1735  The Intersection RNN is built for use in deeply stacked
1736  RNNs so it may not achieve best performance with depth 1.
1737  """
1738
1739  def __init__(self,
1740               num_units,
1741               num_in_proj=None,
1742               initializer=None,
1743               forget_bias=1.0,
1744               y_activation=nn_ops.relu,
1745               reuse=None):
1746    """Initialize the parameters for an +RNN cell.
1747
1748    Args:
1749      num_units: int, The number of units in the +RNN cell
1750      num_in_proj: (optional) int, The input dimensionality for the RNN.
1751        If creating the first layer of an +RNN, this should be set to
1752        `num_units`. Otherwise, this should be set to `None` (default).
1753        If `None`, dimensionality of `inputs` should be equal to `num_units`,
1754        otherwise ValueError is thrown.
1755      initializer: (optional) The initializer to use for the weight matrices.
1756      forget_bias: (optional) float, default 1.0, The initial bias of the
1757        forget gates, used to reduce the scale of forgetting at the beginning
1758        of the training.
1759      y_activation: (optional) Activation function of the states passed
1760        through depth. Default is `tf.nn.relu`.
1761      reuse: (optional) Python boolean describing whether to reuse variables
1762        in an existing scope.  If not `True`, and the existing scope already has
1763        the given variables, an error is raised.
1764    """
1765    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
1766    self._num_units = num_units
1767    self._initializer = initializer
1768    self._forget_bias = forget_bias
1769    self._num_input_proj = num_in_proj
1770    self._y_activation = y_activation
1771    self._reuse = reuse
1772    self._linear1 = None
1773    self._linear2 = None
1774
1775  @property
1776  def state_size(self):
1777    return self._num_units
1778
1779  @property
1780  def output_size(self):
1781    return self._num_units
1782
1783  def call(self, inputs, state):
1784    """Run one step of the Intersection RNN.
1785
1786    Args:
1787      inputs: input Tensor, 2D, batch x input size.
1788      state: state Tensor, 2D, batch x num units.
1789
1790    Returns:
1791      new_y: batch x num units, Tensor representing the output of the +RNN
1792        after reading `inputs` when previous state was `state`.
1793      new_state: batch x num units, Tensor representing the state of the +RNN
1794        after reading `inputs` when previous state was `state`.
1795
1796    Raises:
1797      ValueError: If input size cannot be inferred from `inputs` via
1798        static shape inference.
1799      ValueError: If input size != output size (these must be equal when
1800        using the Intersection RNN).
1801    """
1802    sigmoid = math_ops.sigmoid
1803    tanh = math_ops.tanh
1804
1805    input_size = inputs.get_shape().with_rank(2).dims[1]
1806    if input_size.value is None:
1807      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1808
1809    with vs.variable_scope(
1810        vs.get_variable_scope(), initializer=self._initializer):
1811      # read-in projections (should be used for first layer in deep +RNN
1812      # to transform size of inputs from I --> N)
1813      if input_size.value != self._num_units:
1814        if self._num_input_proj:
1815          with vs.variable_scope("in_projection"):
1816            if self._linear1 is None:
1817              self._linear1 = _Linear(inputs, self._num_units, True)
1818            inputs = self._linear1(inputs)
1819        else:
1820          raise ValueError("Must have input size == output size for "
1821                           "Intersection RNN. To fix, num_in_proj should "
1822                           "be set to num_units at cell init.")
1823
1824      n_dim = i_dim = self._num_units
1825      cell_inputs = array_ops.concat([inputs, state], 1)
1826      if self._linear2 is None:
1827        self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True)
1828      rnn_matrix = self._linear2(cell_inputs)
1829
1830      gh_act = rnn_matrix[:, :n_dim]  # b x n
1831      h_act = rnn_matrix[:, n_dim:2 * n_dim]  # b x n
1832      gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim]  # b x i
1833      y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim]  # b x i
1834
1835      h = tanh(h_act)
1836      y = self._y_activation(y_act)
1837      gh = sigmoid(gh_act + self._forget_bias)
1838      gy = sigmoid(gy_act + self._forget_bias)
1839
1840      new_state = gh * state + (1.0 - gh) * h  # passed thru time
1841      new_y = gy * inputs + (1.0 - gy) * y  # passed thru depth
1842
1843    return new_y, new_state
1844
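
# Illustrative usage sketch (not part of the original module): a two-layer +RNN
# stack.  As the class docstring above requires, the first layer sets
# `num_in_proj=num_units` to project the raw inputs down to N dimensions, while
# deeper layers use the defaults.  Names and shapes are demonstration
# assumptions; graph-mode TF 1.x semantics are assumed.
def _intersection_rnn_example():
  """Builds one step of a stacked +RNN; returns (output, new_state)."""
  batch_size, input_depth, num_units = 4, 10, 8
  first = IntersectionRNNCell(num_units, num_in_proj=num_units)
  deeper = IntersectionRNNCell(num_units)
  stack = rnn_cell_impl.MultiRNNCell([first, deeper])
  inputs = array_ops.zeros([batch_size, input_depth])
  state = stack.zero_state(batch_size, dtypes.float32)
  output, new_state = stack(inputs, state)  # output: [batch_size, num_units]
  return output, new_state
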
1845
1846_REGISTERED_OPS = None
1847
1848
1849class CompiledWrapper(rnn_cell_impl.RNNCell):
1850  """Wraps step execution in an XLA JIT scope."""
1851
1852  def __init__(self, cell, compile_stateful=False):
1853    """Create CompiledWrapper cell.
1854
1855    Args:
1856      cell: Instance of `RNNCell`.
1857      compile_stateful: Whether to compile stateful ops like initializers
1858        and random number generators (default: False).
1859    """
1860    self._cell = cell
1861    self._compile_stateful = compile_stateful
1862
1863  @property
1864  def state_size(self):
1865    return self._cell.state_size
1866
1867  @property
1868  def output_size(self):
1869    return self._cell.output_size
1870
1871  def zero_state(self, batch_size, dtype):
1872    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
1873      return self._cell.zero_state(batch_size, dtype)
1874
1875  def __call__(self, inputs, state, scope=None):
1876    if self._compile_stateful:
1877      compile_ops = True
1878    else:
1879
1880      def compile_ops(node_def):
1881        global _REGISTERED_OPS
1882        if _REGISTERED_OPS is None:
1883          _REGISTERED_OPS = op_def_registry.get_registered_ops()
1884        return not _REGISTERED_OPS[node_def.op].is_stateful
1885
1886    with jit.experimental_jit_scope(compile_ops=compile_ops):
1887      return self._cell(inputs, state, scope=scope)
1888
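
# Illustrative usage sketch (not part of the original module): running an LSTM
# step inside `CompiledWrapper`, which marks the step ops for XLA JIT
# compilation (stateless ops only, since `compile_stateful` defaults to False).
# Names and shapes are demonstration assumptions; graph-mode TF 1.x with XLA
# available is assumed.
def _compiled_wrapper_example():
  """Builds one JIT-scoped LSTM step; returns (output, new_state)."""
  batch_size, input_depth, num_units = 4, 10, 8
  cell = CompiledWrapper(rnn_cell_impl.LSTMCell(num_units))
  inputs = array_ops.zeros([batch_size, input_depth])
  state = cell.zero_state(batch_size, dtypes.float32)
  output, new_state = cell(inputs, state)
  return output, new_state
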
1889
1890def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
1891  """Returns an exponential distribution initializer.
1892
1893  Args:
1894    minval: float or a scalar float Tensor with value > 0. Lower bound of the
1895        range of random values to generate.
1896    maxval: float or a scalar float Tensor with value > minval. Upper bound of
1897        the range of random values to generate.
1898    seed: An integer. Used to create random seeds.
1899    dtype: The data type.
1900
1901  Returns:
1902    An initializer that generates tensors with an exponential distribution.
1903  """
1904
1905  def _initializer(shape, dtype=dtype, partition_info=None):
1906    del partition_info  # Unused.
1907    return math_ops.exp(
1908        random_ops.random_uniform(
1909            shape, math_ops.log(minval), math_ops.log(maxval), dtype,
1910            seed=seed))
1911
1912  return _initializer
1913
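
# Illustrative usage sketch (not part of the original module): creating a
# variable whose entries are drawn from e^U(log(minval), log(maxval)), the same
# pattern `PhasedLSTMCell` uses for its "period" variable below.  The variable
# name and shape are demonstration assumptions.
def _random_exp_initializer_example():
  """Creates a variable initialized from an exponential-of-uniform draw."""
  init = _random_exp_initializer(minval=1.0, maxval=1000.0, seed=42)
  return vs.get_variable("example_period", shape=[8], initializer=init)
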
1914
1915class PhasedLSTMCell(rnn_cell_impl.RNNCell):
1916  """Phased LSTM recurrent network cell.
1917
1918  https://arxiv.org/pdf/1610.09513v1.pdf
1919  """
1920
1921  def __init__(self,
1922               num_units,
1923               use_peepholes=False,
1924               leak=0.001,
1925               ratio_on=0.1,
1926               trainable_ratio_on=True,
1927               period_init_min=1.0,
1928               period_init_max=1000.0,
1929               reuse=None):
1930    """Initialize the Phased LSTM cell.
1931
1932    Args:
1933      num_units: int, The number of units in the Phased LSTM cell.
1934      use_peepholes: bool, set True to enable peephole connections.
1935      leak: float or scalar float Tensor with value in [0, 1]. Leak applied
1936          during training.
1937      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
1938          period during which the gates are open.
1939      trainable_ratio_on: bool, whether ratio_on is trainable.
1940      period_init_min: float or scalar float Tensor. With value > 0.
1941          Minimum value of the initialized period.
1942          The period values are initialized by drawing from the distribution:
1943          e^U(log(period_init_min), log(period_init_max))
1944          Where U(.,.) is the uniform distribution.
1945      period_init_max: float or scalar float Tensor.
1946          With value > period_init_min. Maximum value of the initialized period.
1947      reuse: (optional) Python boolean describing whether to reuse variables
1948        in an existing scope. If not `True`, and the existing scope already has
1949        the given variables, an error is raised.
1950    """
1951    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
1952    self._num_units = num_units
1953    self._use_peepholes = use_peepholes
1954    self._leak = leak
1955    self._ratio_on = ratio_on
1956    self._trainable_ratio_on = trainable_ratio_on
1957    self._period_init_min = period_init_min
1958    self._period_init_max = period_init_max
1959    self._reuse = reuse
1960    self._linear1 = None
1961    self._linear2 = None
1962    self._linear3 = None
1963
1964  @property
1965  def state_size(self):
1966    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
1967
1968  @property
1969  def output_size(self):
1970    return self._num_units
1971
1972  def _mod(self, x, y):
1973    """Modulo function that propagates x gradients."""
1974    return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x
1975
1976  def _get_cycle_ratio(self, time, phase, period):
1977    """Compute the cycle ratio in the dtype of the time."""
1978    phase_casted = math_ops.cast(phase, dtype=time.dtype)
1979    period_casted = math_ops.cast(period, dtype=time.dtype)
1980    shifted_time = time - phase_casted
1981    cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
1982    return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
1983
1984  def call(self, inputs, state):
1985    """Phased LSTM Cell.
1986
1987    Args:
1988      inputs: A tuple of 2 Tensors.
1989         The first Tensor has shape [batch, 1], and type float32 or float64.
1990         It stores the time.
1991         The second Tensor has shape [batch, features_size], and type float32.
1992         It stores the features.
1993      state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.
1994
1995    Returns:
1996      A tuple containing:
1997      - A Tensor of float32, and shape [batch_size, num_units], representing the
1998        output of the cell.
1999      - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
2000        [batch_size, num_units], representing the new state and the output.
2001    """
2002    (c_prev, h_prev) = state
2003    (time, x) = inputs
2004
2005    in_mask_gates = [x, h_prev]
2006    if self._use_peepholes:
2007      in_mask_gates.append(c_prev)
2008
2009    with vs.variable_scope("mask_gates"):
2010      if self._linear1 is None:
2011        self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True)
2012
2013      mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates))
2014      [input_gate, forget_gate] = array_ops.split(
2015          axis=1, num_or_size_splits=2, value=mask_gates)
2016
2017    with vs.variable_scope("new_input"):
2018      if self._linear2 is None:
2019        self._linear2 = _Linear([x, h_prev], self._num_units, True)
2020      new_input = math_ops.tanh(self._linear2([x, h_prev]))
2021
2022    new_c = (c_prev * forget_gate + input_gate * new_input)
2023
2024    in_out_gate = [x, h_prev]
2025    if self._use_peepholes:
2026      in_out_gate.append(new_c)
2027
2028    with vs.variable_scope("output_gate"):
2029      if self._linear3 is None:
2030        self._linear3 = _Linear(in_out_gate, self._num_units, True)
2031      output_gate = math_ops.sigmoid(self._linear3(in_out_gate))
2032
2033    new_h = math_ops.tanh(new_c) * output_gate
2034
2035    period = vs.get_variable(
2036        "period", [self._num_units],
2037        initializer=_random_exp_initializer(self._period_init_min,
2038                                            self._period_init_max))
2039    phase = vs.get_variable(
2040        "phase", [self._num_units],
2041        initializer=init_ops.random_uniform_initializer(0.,
2042                                                        period.initial_value))
2043    ratio_on = vs.get_variable(
2044        "ratio_on", [self._num_units],
2045        initializer=init_ops.constant_initializer(self._ratio_on),
2046        trainable=self._trainable_ratio_on)
2047
2048    cycle_ratio = self._get_cycle_ratio(time, phase, period)
2049
2050    k_up = 2 * cycle_ratio / ratio_on
2051    k_down = 2 - k_up
2052    k_closed = self._leak * cycle_ratio
2053
2054    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
2055    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
2056
2057    new_c = k * new_c + (1 - k) * c_prev
2058    new_h = k * new_h + (1 - k) * h_prev
2059
2060    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
2061
2062    return new_h, new_state
2063
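
# Illustrative usage sketch (not part of the original module): one step of
# `PhasedLSTMCell`.  Unlike the other cells here, its input is a 2-tuple of
# (timestamps, features).  Names and shapes are demonstration assumptions;
# graph-mode TF 1.x semantics are assumed.
def _phased_lstm_example():
  """Builds one Phased LSTM step; returns (output, new_state)."""
  batch_size, input_depth, num_units = 4, 10, 8
  cell = PhasedLSTMCell(num_units, use_peepholes=True)
  times = array_ops.zeros([batch_size, 1])  # per-example timestamps
  features = array_ops.zeros([batch_size, input_depth])
  state = cell.zero_state(batch_size, dtypes.float32)
  output, new_state = cell((times, features), state)
  return output, new_state
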
2064
2065class ConvLSTMCell(rnn_cell_impl.RNNCell):
2066  """Convolutional LSTM recurrent network cell.
2067
2068  https://arxiv.org/pdf/1506.04214v1.pdf
2069  """
2070
2071  def __init__(self,
2072               conv_ndims,
2073               input_shape,
2074               output_channels,
2075               kernel_shape,
2076               use_bias=True,
2077               skip_connection=False,
2078               forget_bias=1.0,
2079               initializers=None,
2080               name="conv_lstm_cell"):
2081    """Construct ConvLSTMCell.
2082
2083    Args:
2084      conv_ndims: Convolution dimensionality (1, 2 or 3).
2085      input_shape: Shape of the input as int tuple, excluding the batch size.
2086      output_channels: int, number of output channels of the conv LSTM.
2087      kernel_shape: Shape of kernel as an int tuple (of size 1, 2 or 3).
2088      use_bias: (bool) Use bias in convolutions.
2089      skip_connection: If set to `True`, concatenate the input to the
2090        output of the conv LSTM. Default: `False`.
2091      forget_bias: Forget bias.
2092      initializers: Unused.
2093      name: Name of the module.
2094
2095    Raises:
2096      ValueError: If `input_shape` is incompatible with `conv_ndims` (its rank
2097        must be `conv_ndims + 1`, excluding the batch dimension).
2098    """
2099    super(ConvLSTMCell, self).__init__(name=name)
2100
2101    if conv_ndims != len(input_shape) - 1:
2102      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
2103          input_shape, conv_ndims))
2104
2105    self._conv_ndims = conv_ndims
2106    self._input_shape = input_shape
2107    self._output_channels = output_channels
2108    self._kernel_shape = list(kernel_shape)
2109    self._use_bias = use_bias
2110    self._forget_bias = forget_bias
2111    self._skip_connection = skip_connection
2112
2113    self._total_output_channels = output_channels
2114    if self._skip_connection:
2115      self._total_output_channels += self._input_shape[-1]
2116
2117    state_size = tensor_shape.TensorShape(
2118        self._input_shape[:-1] + [self._output_channels])
2119    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
2120    self._output_size = tensor_shape.TensorShape(
2121        self._input_shape[:-1] + [self._total_output_channels])
2122
2123  @property
2124  def output_size(self):
2125    return self._output_size
2126
2127  @property
2128  def state_size(self):
2129    return self._state_size
2130
2131  def call(self, inputs, state, scope=None):
2132    cell, hidden = state
2133    new_hidden = _conv([inputs, hidden], self._kernel_shape,
2134                       4 * self._output_channels, self._use_bias)
2135    gates = array_ops.split(
2136        value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)
2137
2138    input_gate, new_input, forget_gate, output_gate = gates
2139    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
2140    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
2141    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)
2142
2143    if self._skip_connection:
2144      output = array_ops.concat([output, inputs], axis=-1)
2145    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
2146    return output, new_state
2147
2148
2149class Conv1DLSTMCell(ConvLSTMCell):
2150  """1D Convolutional LSTM recurrent network cell.
2151
2152  https://arxiv.org/pdf/1506.04214v1.pdf
2153  """
2154
2155  def __init__(self, name="conv_1d_lstm_cell", **kwargs):
2156    """Construct Conv1DLSTM. See `ConvLSTMCell` for more details."""
2157    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs)
2158
2159
2160class Conv2DLSTMCell(ConvLSTMCell):
2161  """2D Convolutional LSTM recurrent network cell.
2162
2163  https://arxiv.org/pdf/1506.04214v1.pdf
2164  """
2165
2166  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
2167    """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
2168    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)
2169
2170
2171class Conv3DLSTMCell(ConvLSTMCell):
2172  """3D Convolutional LSTM recurrent network cell.
2173
2174  https://arxiv.org/pdf/1506.04214v1.pdf
2175  """
2176
2177  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
2178    """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
2179    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)
2180
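
# Illustrative usage sketch (not part of the original module): one step of
# `Conv2DLSTMCell` on an image-shaped batch.  Note that `input_shape` excludes
# the batch dimension.  Names and shapes are demonstration assumptions;
# graph-mode TF 1.x semantics are assumed.
def _conv_2d_lstm_example():
  """Builds one 2D convolutional LSTM step; returns (output, new_state)."""
  batch_size, height, width, channels = 2, 8, 8, 3
  cell = Conv2DLSTMCell(
      input_shape=[height, width, channels],
      output_channels=6,
      kernel_shape=[3, 3])
  inputs = array_ops.zeros([batch_size, height, width, channels])
  state = cell.zero_state(batch_size, dtypes.float32)
  output, new_state = cell(inputs, state)  # output: [2, 8, 8, 6]
  return output, new_state
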
2181
2182def _conv(args, filter_size, num_features, bias, bias_start=0.0):
2183  """Convolution.
2184
2185  Args:
2186    args: a Tensor or a list of 3D, 4D or 5D Tensors with shape
2187      [batch, spatial dimensions..., features].
2188    filter_size: int tuple of filter shape (of size 1, 2 or 3).
2189    num_features: int, number of features.
2190    bias: Whether to use biases in the convolution layer.
2191    bias_start: starting value to initialize the bias; 0 by default.
2192
2193  Returns:
2194    A 3D, 4D, or 5D Tensor with shape [batch ... num_features]
2195
2196  Raises:
2197    ValueError: if some of the arguments have unspecified or wrong shapes.
2198  """
2199
2200  # Calculate the total size of arguments along the last (depth) dimension.
2201  total_arg_size_depth = 0
2202  shapes = [a.get_shape().as_list() for a in args]
2203  shape_length = len(shapes[0])
2204  for shape in shapes:
2205    if len(shape) not in [3, 4, 5]:
2206      raise ValueError("Conv Linear expects 3D, 4D "
2207                       "or 5D arguments: %s" % str(shapes))
2208    if len(shape) != len(shapes[0]):
2209      raise ValueError("Conv Linear expects all args "
2210                       "to be of the same dimension: %s" % str(shapes))
2211    else:
2212      total_arg_size_depth += shape[-1]
2213  dtype = [a.dtype for a in args][0]
2214
2215  # determine correct conv operation
2216  if shape_length == 3:
2217    conv_op = nn_ops.conv1d
2218    strides = 1
2219  elif shape_length == 4:
2220    conv_op = nn_ops.conv2d
2221    strides = shape_length * [1]
2222  elif shape_length == 5:
2223    conv_op = nn_ops.conv3d
2224    strides = shape_length * [1]
2225
2226  # Now the computation.
2227  kernel = vs.get_variable(
2228      "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype)
2229  if len(args) == 1:
2230    res = conv_op(args[0], kernel, strides, padding="SAME")
2231  else:
2232    res = conv_op(
2233        array_ops.concat(axis=shape_length - 1, values=args),
2234        kernel,
2235        strides,
2236        padding="SAME")
2237  if not bias:
2238    return res
2239  bias_term = vs.get_variable(
2240      "biases", [num_features],
2241      dtype=dtype,
2242      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
2243  return res + bias_term
2244
2245
2246class GLSTMCell(rnn_cell_impl.RNNCell):
2247  """Group LSTM cell (G-LSTM).
2248
2249  The implementation is based on:
2250
2251    https://arxiv.org/abs/1703.10722
2252
2253  O. Kuchaiev and B. Ginsburg
2254  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
2255
2256  In brief, a G-LSTM cell consists of one LSTM sub-cell per group, where each
2257  sub-cell operates on an evenly-sized sub-vector of the input and produces an
2258  evenly-sized sub-vector of the output.  For example, a G-LSTM cell with 128
2259  units and 4 groups consists of 4 LSTMs sub-cells with 32 units each.  If that
2260  G-LSTM cell is fed a 200-dim input, then each sub-cell receives a 50-dim part
2261  of the input and produces a 32-dim part of the output.
2262  """
2263
2264  def __init__(self,
2265               num_units,
2266               initializer=None,
2267               num_proj=None,
2268               number_of_groups=1,
2269               forget_bias=1.0,
2270               activation=math_ops.tanh,
2271               reuse=None):
2272    """Initialize the parameters of G-LSTM cell.
2273
2274    Args:
2275      num_units: int, The number of units in the G-LSTM cell
2276      initializer: (optional) The initializer to use for the weight and
2277        projection matrices.
2278      num_proj: (optional) int, The output dimensionality for the projection
2279        matrices.  If None, no projection is performed.
2280      number_of_groups: (optional) int, number of groups to use.
2281        If `number_of_groups` is 1, it should be equivalent to an LSTM cell.
2282      forget_bias: Biases of the forget gate are initialized by default to 1
2283        in order to reduce the scale of forgetting at the beginning of
2284        the training.
2285      activation: Activation function of the inner states.
2286      reuse: (optional) Python boolean describing whether to reuse variables
2287        in an existing scope.  If not `True`, and the existing scope already
2288        has the given variables, an error is raised.
2289
2290    Raises:
2291      ValueError: If `num_units` or `num_proj` is not divisible by
2292        `number_of_groups`.
2293    """
2294    super(GLSTMCell, self).__init__(_reuse=reuse)
2295    self._num_units = num_units
2296    self._initializer = initializer
2297    self._num_proj = num_proj
2298    self._forget_bias = forget_bias
2299    self._activation = activation
2300    self._number_of_groups = number_of_groups
2301
2302    if self._num_units % self._number_of_groups != 0:
2303      raise ValueError("num_units must be divisible by number_of_groups")
2304    if self._num_proj:
2305      if self._num_proj % self._number_of_groups != 0:
2306        raise ValueError("num_proj must be divisible by number_of_groups")
2307      self._group_shape = [
2308          int(self._num_proj / self._number_of_groups),
2309          int(self._num_units / self._number_of_groups)
2310      ]
2311    else:
2312      self._group_shape = [
2313          int(self._num_units / self._number_of_groups),
2314          int(self._num_units / self._number_of_groups)
2315      ]
2316
2317    if num_proj:
2318      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
2319      self._output_size = num_proj
2320    else:
2321      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
2322      self._output_size = num_units
2323    self._linear1 = [None] * number_of_groups
2324    self._linear2 = None
2325
2326  @property
2327  def state_size(self):
2328    return self._state_size
2329
2330  @property
2331  def output_size(self):
2332    return self._output_size
2333
2334  def _get_input_for_group(self, inputs, group_id, group_size):
2335    """Slices inputs into groups to prepare for processing by cell's groups.
2336
2337    Args:
2338      inputs: cell input or its previous state,
2339              a Tensor, 2D, [batch x num_units]
2340      group_id: group id, a Scalar, for which to prepare input
2341      group_size: size of the group
2342
2343    Returns:
2344      subset of inputs corresponding to group "group_id",
2345      a Tensor, 2D, [batch x num_units/number_of_groups]
2346    """
2347    return array_ops.slice(
2348        input_=inputs,
2349        begin=[0, group_id * group_size],
2350        size=[self._batch_size, group_size],
2351        name=("GLSTM_group%d_input_generation" % group_id))
2352
2353  def call(self, inputs, state):
2354    """Run one step of G-LSTM.
2355
2356    Args:
2357      inputs: input Tensor, 2D, [batch x num_inputs].  num_inputs must be
2358        statically-known and evenly divisible into groups.  The innermost
2359        vectors of the inputs are split into evenly-sized sub-vectors and fed
2360        into the per-group LSTM sub-cells.
2361      state: this must be a tuple of state Tensors, both `2-D`, with column
2362        sizes `c_state` and `m_state`.
2363
2364    Returns:
2365      A tuple containing:
2366
2367      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
2368        G-LSTM after reading `inputs` when previous state was `state`.
2369        Here output_dim is:
2370           num_proj if num_proj was set,
2371           num_units otherwise.
2372      - LSTMStateTuple representing the new state of G-LSTM cell
2373        after reading `inputs` when the previous state was `state`.
2374
2375    Raises:
2376      ValueError: If input size cannot be inferred from inputs via
2377        static shape inference, or if the input shape is incompatible
2378        with the number of groups.
2379    """
2380    (c_prev, m_prev) = state
2381
2382    self._batch_size = tensor_shape.dimension_value(
2383        inputs.shape[0]) or array_ops.shape(inputs)[0]
2384
2385    # The input size must be statically known and evenly divisible by the
2386    # number of groups; compute the per-group slice size of the input.
2387    input_size = tensor_shape.dimension_value(inputs.shape[1])
2388    if input_size is None:
2389      raise ValueError("input size must be statically known")
2390    if input_size % self._number_of_groups != 0:
2391      raise ValueError(
2392          "input size (%d) must be divisible by number_of_groups (%d)" %
2393          (input_size, self._number_of_groups))
2394    input_group_size = int(input_size / self._number_of_groups)
2395
2396    dtype = inputs.dtype
2397    scope = vs.get_variable_scope()
2398    with vs.variable_scope(scope, initializer=self._initializer):
2399      i_parts = []
2400      j_parts = []
2401      f_parts = []
2402      o_parts = []
2403
2404      for group_id in range(self._number_of_groups):
2405        with vs.variable_scope("group%d" % group_id):
2406          x_g_id = array_ops.concat(
2407              [
2408                  self._get_input_for_group(inputs, group_id, input_group_size),
2409                  self._get_input_for_group(m_prev, group_id,
2410                                            self._group_shape[0])
2411              ],
2412              axis=1)
2413          linear = self._linear1[group_id]
2414          if linear is None:
2415            linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
2416            self._linear1[group_id] = linear
2417          R_k = linear(x_g_id)  # pylint: disable=invalid-name
2418          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)
2419
2420        i_parts.append(i_k)
2421        j_parts.append(j_k)
2422        f_parts.append(f_k)
2423        o_parts.append(o_k)
2424
2425      bi = vs.get_variable(
2426          name="bias_i",
2427          shape=[self._num_units],
2428          dtype=dtype,
2429          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2430      bj = vs.get_variable(
2431          name="bias_j",
2432          shape=[self._num_units],
2433          dtype=dtype,
2434          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2435      bf = vs.get_variable(
2436          name="bias_f",
2437          shape=[self._num_units],
2438          dtype=dtype,
2439          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2440      bo = vs.get_variable(
2441          name="bias_o",
2442          shape=[self._num_units],
2443          dtype=dtype,
2444          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
2445
2446      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
2447      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
2448      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
2449      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)
2450
2451    c = (
2452        math_ops.sigmoid(f + self._forget_bias) * c_prev +
2453        math_ops.sigmoid(i) * math_ops.tanh(j))
2454    m = math_ops.sigmoid(o) * self._activation(c)
2455
2456    if self._num_proj is not None:
2457      with vs.variable_scope("projection"):
2458        if self._linear2 is None:
2459          self._linear2 = _Linear(m, self._num_proj, False)
2460        m = self._linear2(m)
2461
2462    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
2463    return m, new_state
2464
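
# Illustrative usage sketch (not part of the original module): one step of a
# 4-group G-LSTM.  Both `num_units` and the (statically known) input depth must
# be divisible by `number_of_groups`.  Names and shapes are demonstration
# assumptions; graph-mode TF 1.x semantics are assumed.
def _glstm_example():
  """Builds one G-LSTM step; returns (output, new_state)."""
  batch_size, input_depth, num_units, groups = 4, 20, 16, 4
  cell = GLSTMCell(num_units, number_of_groups=groups)
  inputs = array_ops.zeros([batch_size, input_depth])  # 20 / 4 = 5 per group
  state = cell.zero_state(batch_size, dtypes.float32)
  output, new_state = cell(inputs, state)  # output: [batch_size, num_units]
  return output, new_state
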
2465
2466class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
2467  """Long short-term memory unit (LSTM) recurrent network cell.
2468
2469  The default non-peephole implementation is based on:
2470
2471    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
2472
2473  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
2474  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
2475
2476  The peephole implementation is based on:
2477
2478    https://research.google.com/pubs/archive/43905.pdf
2479
2480  Hasim Sak, Andrew Senior, and Francoise Beaufays.
2481  "Long short-term memory recurrent neural network architectures for
2482   large scale acoustic modeling." INTERSPEECH, 2014.
2483
2484  The class uses optional peep-hole connections, optional cell clipping, and
2485  an optional projection layer.
2486
2487  Layer normalization implementation is based on:
2488
2489    https://arxiv.org/abs/1607.06450.
2490
2491  "Layer Normalization"
2492  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
2493
2494  and is applied before the internal nonlinearities.
2495
2496  """
2497
2498  def __init__(self,
2499               num_units,
2500               use_peepholes=False,
2501               cell_clip=None,
2502               initializer=None,
2503               num_proj=None,
2504               proj_clip=None,
2505               forget_bias=1.0,
2506               activation=None,
2507               layer_norm=False,
2508               norm_gain=1.0,
2509               norm_shift=0.0,
2510               reuse=None):
2511    """Initialize the parameters for an LSTM cell.
2512
2513    Args:
2514      num_units: int, The number of units in the LSTM cell
2515      use_peepholes: bool, set True to enable diagonal/peephole connections.
2516      cell_clip: (optional) A float value; if provided, the cell state is clipped
2517        by this value prior to the cell output activation.
2518      initializer: (optional) The initializer to use for the weight and
2519        projection matrices.
2520      num_proj: (optional) int, The output dimensionality for the projection
2521        matrices.  If None, no projection is performed.
2522      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
2523        provided, then the projected values are clipped elementwise to within
2524        `[-proj_clip, proj_clip]`.
2525      forget_bias: Biases of the forget gate are initialized by default to 1
2526        in order to reduce the scale of forgetting at the beginning of
2527        the training. Must set it manually to `0.0` when restoring from
2528        CudnnLSTM trained checkpoints.
2529      activation: Activation function of the inner states.  Default: `tanh`.
2530      layer_norm: If `True`, layer normalization will be applied.
2531      norm_gain: float, The layer normalization gain initial value. If
2532        `layer_norm` has been set to `False`, this argument will be ignored.
2533      norm_shift: float, The layer normalization shift initial value. If
2534        `layer_norm` has been set to `False`, this argument will be ignored.
2535      reuse: (optional) Python boolean describing whether to reuse variables
2536        in an existing scope.  If not `True`, and the existing scope already has
2537        the given variables, an error is raised.
2538
2539      When restoring from CudnnLSTM-trained checkpoints,
2540      `CudnnCompatibleLSTMCell` must be used instead.
2541    """
2542    super(LayerNormLSTMCell, self).__init__(_reuse=reuse)
2543
2544    self._num_units = num_units
2545    self._use_peepholes = use_peepholes
2546    self._cell_clip = cell_clip
2547    self._initializer = initializer
2548    self._num_proj = num_proj
2549    self._proj_clip = proj_clip
2550    self._forget_bias = forget_bias
2551    self._activation = activation or math_ops.tanh
2552    self._layer_norm = layer_norm
2553    self._norm_gain = norm_gain
2554    self._norm_shift = norm_shift
2555
2556    if num_proj:
2557      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj))
2558      self._output_size = num_proj
2559    else:
2560      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units))
2561      self._output_size = num_units
2562
2563  @property
2564  def state_size(self):
2565    return self._state_size
2566
2567  @property
2568  def output_size(self):
2569    return self._output_size
2570
2571  def _linear(self,
2572              args,
2573              output_size,
2574              bias,
2575              bias_initializer=None,
2576              kernel_initializer=None,
2577              layer_norm=False):
2578    """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable.
2579
2580    Args:
2581      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
2582      output_size: int, second dimension of W[i].
2583      bias: boolean, whether to add a bias term or not.
2584      bias_initializer: starting value to initialize the bias
2585        (default is all zeros).
2586      kernel_initializer: starting value to initialize the weight.
2587      layer_norm: boolean, whether to apply layer normalization.
2588
2590    Returns:
2591      A 2D Tensor with shape [batch x output_size] taking value
2592      sum_i(args[i] * W[i]), where each W[i] is a newly created Variable.
2593
2594    Raises:
2595      ValueError: if some of the arguments have unspecified or wrong shapes.
2596    """
2597    if args is None or (nest.is_sequence(args) and not args):
2598      raise ValueError("`args` must be specified")
2599    if not nest.is_sequence(args):
2600      args = [args]
2601
2602    # Calculate the total size of arguments on dimension 1.
2603    total_arg_size = 0
2604    shapes = [a.get_shape() for a in args]
2605    for shape in shapes:
2606      if shape.ndims != 2:
2607        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
2608      if tensor_shape.dimension_value(shape[1]) is None:
2609        raise ValueError("linear expects shape[1] to be provided for shape %s, "
2610                         "but saw %s" % (shape, shape[1]))
2611      else:
2612        total_arg_size += tensor_shape.dimension_value(shape[1])
2613
2614    dtype = [a.dtype for a in args][0]
2615
2616    # Now the computation.
2617    scope = vs.get_variable_scope()
2618    with vs.variable_scope(scope) as outer_scope:
2619      weights = vs.get_variable(
2620          "kernel", [total_arg_size, output_size],
2621          dtype=dtype,
2622          initializer=kernel_initializer)
2623      if len(args) == 1:
2624        res = math_ops.matmul(args[0], weights)
2625      else:
2626        res = math_ops.matmul(array_ops.concat(args, 1), weights)
2627      if not bias:
2628        return res
2629      with vs.variable_scope(outer_scope) as inner_scope:
2630        inner_scope.set_partitioner(None)
2631        if bias_initializer is None:
2632          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
2633        biases = vs.get_variable(
2634            "bias", [output_size], dtype=dtype, initializer=bias_initializer)
2635
2636    if not layer_norm:
2637      res = nn_ops.bias_add(res, biases)
2638
2639    return res
2640
2641  def call(self, inputs, state):
2642    """Run one step of LSTM.
2643
2644    Args:
2645      inputs: input Tensor, 2D, batch x num_units.
2646      state: this must be a tuple of state Tensors, both `2-D`, with column
2647        sizes `c_state` and `m_state`.
2649
2650    Returns:
2651      A tuple containing:
2652
2653      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
2654        LSTM after reading `inputs` when previous state was `state`.
2655        Here output_dim is:
2656           num_proj if num_proj was set,
2657           num_units otherwise.
2658      - Tensor(s) representing the new state of LSTM after reading `inputs` when
2659        the previous state was `state`.  Same type and shape(s) as `state`.
2660
2661    Raises:
2662      ValueError: If input size cannot be inferred from inputs via
2663        static shape inference.
2664    """
2665    sigmoid = math_ops.sigmoid
2666
2667    (c_prev, m_prev) = state
2668
2669    dtype = inputs.dtype
2670    input_size = inputs.get_shape().with_rank(2).dims[1]
2671    if input_size.value is None:
2672      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
2673    scope = vs.get_variable_scope()
2674    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
2675
2676      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
2677      lstm_matrix = self._linear(
2678          [inputs, m_prev],
2679          4 * self._num_units,
2680          bias=True,
2681          bias_initializer=None,
2682          layer_norm=self._layer_norm)
2683      i, j, f, o = array_ops.split(
2684          value=lstm_matrix, num_or_size_splits=4, axis=1)
2685
2686      if self._layer_norm:
2687        i = _norm(self._norm_gain, self._norm_shift, i, "input")
2688        j = _norm(self._norm_gain, self._norm_shift, j, "transform")
2689        f = _norm(self._norm_gain, self._norm_shift, f, "forget")
2690        o = _norm(self._norm_gain, self._norm_shift, o, "output")
2691
2692      # Diagonal connections
2693      if self._use_peepholes:
2694        with vs.variable_scope(unit_scope):
2695          w_f_diag = vs.get_variable(
2696              "w_f_diag", shape=[self._num_units], dtype=dtype)
2697          w_i_diag = vs.get_variable(
2698              "w_i_diag", shape=[self._num_units], dtype=dtype)
2699          w_o_diag = vs.get_variable(
2700              "w_o_diag", shape=[self._num_units], dtype=dtype)
2701
2702      if self._use_peepholes:
2703        c = (
2704            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
2705            sigmoid(i + w_i_diag * c_prev) * self._activation(j))
2706      else:
2707        c = (
2708            sigmoid(f + self._forget_bias) * c_prev +
2709            sigmoid(i) * self._activation(j))
2710
2711      if self._layer_norm:
2712        c = _norm(self._norm_gain, self._norm_shift, c, "state")
2713
2714      if self._cell_clip is not None:
2715        # pylint: disable=invalid-unary-operand-type
2716        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
2717        # pylint: enable=invalid-unary-operand-type
2718      if self._use_peepholes:
2719        m = sigmoid(o + w_o_diag * c) * self._activation(c)
2720      else:
2721        m = sigmoid(o) * self._activation(c)
2722
2723      if self._num_proj is not None:
2724        with vs.variable_scope("projection"):
2725          m = self._linear(m, self._num_proj, bias=False)
2726
2727        if self._proj_clip is not None:
2728          # pylint: disable=invalid-unary-operand-type
2729          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
2730          # pylint: enable=invalid-unary-operand-type
2731
2732    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
2733    return m, new_state
2734
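
# A minimal NumPy sketch (illustration only; not used elsewhere in this
# module) of the basic gate arithmetic performed by `call` above, with layer
# normalization, peepholes, cell clipping and the projection all disabled.
# The helper name and argument layout are hypothetical.
def _basic_lstm_step_numpy_sketch(x, m_prev, c_prev, kernel, bias,
                                  forget_bias=1.0):
  """One plain LSTM step: [x, m_prev] . kernel + bias, split into i, j, f, o."""
  import numpy as np

  def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

  lstm_matrix = np.concatenate([x, m_prev], axis=1).dot(kernel) + bias
  i, j, f, o = np.split(lstm_matrix, 4, axis=1)
  c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(j)
  m = sigmoid(o) * np.tanh(c)
  return m, c
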
2735
2736class SRUCell(rnn_cell_impl.LayerRNNCell):
2737  """SRU, Simple Recurrent Unit.
2738
2739     Implementation based on
2740     Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).
2741
2742     This variation of an RNN cell is characterized by the simplified data
2743     dependence between the hidden states of two consecutive time steps.
2744     Traditionally, the hidden state from a cell at time step t-1 needs to be
2745     multiplied with a matrix W_hh before being fed into the ensuing cell at
2746     time step t. This flavor of RNN replaces that matrix multiplication
2747     between h_{t-1} and W_hh with a pointwise multiplication, resulting in a
2748     performance gain.
2750
2751  Args:
2752    num_units: int, The number of units in the SRU cell.
2753    activation: Nonlinearity to use.  Default: `tanh`.
2754    reuse: (optional) Python boolean describing whether to reuse variables
2755      in an existing scope.  If not `True`, and the existing scope already has
2756      the given variables, an error is raised.
2757    name: (optional) String, the name of the layer. Layers with the same name
2758      will share weights, but to avoid mistakes we require reuse=True in such
2759      cases.
2760    **kwargs: Additional keyword arguments.
2761  """
2762
2763  def __init__(self, num_units, activation=None, reuse=None, name=None,
2764               **kwargs):
2765    super(SRUCell, self).__init__(_reuse=reuse, name=name, **kwargs)
2766    self._num_units = num_units
2767    self._activation = activation or math_ops.tanh
2768
2769    # Restrict inputs to be 2-dimensional matrices
2770    self.input_spec = input_spec.InputSpec(ndim=2)
2771
2772  @property
2773  def state_size(self):
2774    return self._num_units
2775
2776  @property
2777  def output_size(self):
2778    return self._num_units
2779
2780  def build(self, inputs_shape):
2781    if tensor_shape.dimension_value(inputs_shape[1]) is None:
2782      raise ValueError(
2783          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
2784
2785    input_depth = tensor_shape.dimension_value(inputs_shape[1])
2786
2787    # pylint: disable=protected-access
2788    self._kernel = self.add_variable(
2789        rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
2790        shape=[input_depth, 4 * self._num_units])
2791    # pylint: enable=protected-access
2792    self._bias = self.add_variable(
2793        rnn_cell_impl._BIAS_VARIABLE_NAME,  # pylint: disable=protected-access
2794        shape=[2 * self._num_units],
2795        initializer=init_ops.zeros_initializer)
2796
2797    self._built = True
2798
2799  def call(self, inputs, state):
2800    """Simple recurrent unit (SRU) with num_units cells."""
2801
2802    U = math_ops.matmul(inputs, self._kernel)  # pylint: disable=invalid-name
2803    x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split(
2804        value=U, num_or_size_splits=4, axis=1)
2805
2806    f_r = math_ops.sigmoid(
2807        nn_ops.bias_add(
2808            array_ops.concat([f_intermediate, r_intermediate], 1), self._bias))
2809    f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
2810
2811    c = f * state + (1.0 - f) * x_bar
2812    h = r * self._activation(c) + (1.0 - r) * x_tx
2813
2814    return h, c
2815
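
# A minimal NumPy sketch (illustration only) of one SRU step, mirroring
# `SRUCell.call` above: the previous state `c_prev` is only ever combined
# element-wise, never multiplied by a recurrent weight matrix. The helper
# name is hypothetical.
def _sru_step_numpy_sketch(x, c_prev, kernel, bias):
  """One SRU step; `kernel` is [input_depth, 4 * num_units], `bias` is
  [2 * num_units]."""
  import numpy as np

  def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

  x_bar, f_in, r_in, x_tx = np.split(x.dot(kernel), 4, axis=1)
  f, r = np.split(sigmoid(np.concatenate([f_in, r_in], axis=1) + bias), 2,
                  axis=1)
  c = f * c_prev + (1.0 - f) * x_bar  # element-wise state update
  h = r * np.tanh(c) + (1.0 - r) * x_tx  # highway-style output
  return h, c
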
2816
2817class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
2818  """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.
2819
2820    The weight-norm implementation is based on:
2821    https://arxiv.org/abs/1602.07868
2822    Tim Salimans, Diederik P. Kingma.
2823    Weight Normalization: A Simple Reparameterization to Accelerate
2824    Training of Deep Neural Networks
2825
2826    The default LSTM implementation is based on:
2827
2828      https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
2829
2830    Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
2831    "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
2832
2833    The class uses optional peephole connections, optional cell clipping
2834    and an optional projection layer.
2835
2836    The optional peephole implementation is based on:
2837    https://research.google.com/pubs/archive/43905.pdf
2838    Hasim Sak, Andrew Senior, and Francoise Beaufays.
2839    "Long short-term memory recurrent neural network architectures for
2840    large scale acoustic modeling." INTERSPEECH, 2014.
2841  """
2842
2843  def __init__(self,
2844               num_units,
2845               norm=True,
2846               use_peepholes=False,
2847               cell_clip=None,
2848               initializer=None,
2849               num_proj=None,
2850               proj_clip=None,
2851               forget_bias=1,
2852               activation=None,
2853               reuse=None):
2854    """Initialize the parameters of a weight-normalized LSTM cell.
2855
2856    Args:
2857      num_units: int, The number of units in the LSTM cell.
2858      norm: If `True`, apply normalization to the weight matrices. If `False`,
2859        the result is identical to that obtained from `rnn_cell_impl.LSTMCell`.
2860      use_peepholes: bool, set `True` to enable diagonal/peephole connections.
2861      cell_clip: (optional) A float value, if provided the cell state is clipped
2862        by this value prior to the cell output activation.
2863      initializer: (optional) The initializer to use for the weight matrices.
2864      num_proj: (optional) int, The output dimensionality for the projection
2865        matrices.  If None, no projection is performed.
2866      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
2867        provided, then the projected values are clipped elementwise to within
2868        `[-proj_clip, proj_clip]`.
2869      forget_bias: Biases of the forget gate are initialized by default to 1
2870        in order to reduce the scale of forgetting at the beginning of
2871        the training.
2872      activation: Activation function of the inner states.  Default: `tanh`.
2873      reuse: (optional) Python boolean describing whether to reuse variables
2874        in an existing scope.  If not `True`, and the existing scope already has
2875        the given variables, an error is raised.
2876    """
2877    super(WeightNormLSTMCell, self).__init__(_reuse=reuse)
2878
2879    self._scope = "wn_lstm_cell"
2880    self._num_units = num_units
2881    self._norm = norm
2882    self._initializer = initializer
2883    self._use_peepholes = use_peepholes
2884    self._cell_clip = cell_clip
2885    self._num_proj = num_proj
2886    self._proj_clip = proj_clip
2887    self._activation = activation or math_ops.tanh
2888    self._forget_bias = forget_bias
2889
2890    self._weights_variable_name = "kernel"
2891    self._bias_variable_name = "bias"
2892
2893    if num_proj:
2894      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
2895      self._output_size = num_proj
2896    else:
2897      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
2898      self._output_size = num_units
2899
2900  @property
2901  def state_size(self):
2902    return self._state_size
2903
2904  @property
2905  def output_size(self):
2906    return self._output_size
2907
2908  def _normalize(self, weight, name):
2909    """Apply weight normalization.
2910
2911    Args:
2912      weight: a 2D tensor with known number of columns.
2913      name: string, variable name for the normalizer.
2914    Returns:
2915      A tensor with the same shape as `weight`.
2916    """
2917
2918    output_size = weight.get_shape().as_list()[1]
2919    g = vs.get_variable(name, [output_size], dtype=weight.dtype)
2920    return nn_impl.l2_normalize(weight, axis=0) * g
2921
2922  def _linear(self,
2923              args,
2924              output_size,
2925              norm,
2926              bias,
2927              bias_initializer=None,
2928              kernel_initializer=None):
2929    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
2930
2931    Args:
2932      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
2933      output_size: int, second dimension of W[i].
2934      norm: bool, whether to normalize the weights.
2935      bias: boolean, whether to add a bias term or not.
2936      bias_initializer: starting value to initialize the bias
2937        (default is all zeros).
2938      kernel_initializer: starting value to initialize the weight.
2939
2940    Returns:
2941      A 2D Tensor with shape [batch x output_size] equal to
2942      sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
2943
2944    Raises:
2945      ValueError: if any of the arguments has an unspecified or wrong shape.
2946    """
2947    if args is None or (nest.is_sequence(args) and not args):
2948      raise ValueError("`args` must be specified")
2949    if not nest.is_sequence(args):
2950      args = [args]
2951
2952    # Calculate the total size of arguments on dimension 1.
2953    total_arg_size = 0
2954    shapes = [a.get_shape() for a in args]
2955    for shape in shapes:
2956      if shape.ndims != 2:
2957        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
2958      if tensor_shape.dimension_value(shape[1]) is None:
2959        raise ValueError("linear expects shape[1] to be provided for shape %s, "
2960                         "but saw %s" % (shape, shape[1]))
2961      else:
2962        total_arg_size += tensor_shape.dimension_value(shape[1])
2963
2964    dtype = [a.dtype for a in args][0]
2965
2966    # Now the computation.
2967    scope = vs.get_variable_scope()
2968    with vs.variable_scope(scope) as outer_scope:
2969      weights = vs.get_variable(
2970          self._weights_variable_name, [total_arg_size, output_size],
2971          dtype=dtype,
2972          initializer=kernel_initializer)
2973      if norm:
2974        wn = []
2975        st = 0
2976        with ops.control_dependencies(None):
2977          for i in range(len(args)):
2978            en = st + tensor_shape.dimension_value(shapes[i][1])
2979            wn.append(
2980                self._normalize(weights[st:en, :], name="norm_{}".format(i)))
2981            st = en
2982
2983          weights = array_ops.concat(wn, axis=0)
2984
2985      if len(args) == 1:
2986        res = math_ops.matmul(args[0], weights)
2987      else:
2988        res = math_ops.matmul(array_ops.concat(args, 1), weights)
2989      if not bias:
2990        return res
2991
2992      with vs.variable_scope(outer_scope) as inner_scope:
2993        inner_scope.set_partitioner(None)
2994        if bias_initializer is None:
2995          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
2996
2997        biases = vs.get_variable(
2998            self._bias_variable_name, [output_size],
2999            dtype=dtype,
3000            initializer=bias_initializer)
3001
3002      return nn_ops.bias_add(res, biases)
3003
3004  def call(self, inputs, state):
3005    """Run one step of LSTM.
3006
3007    Args:
3008      inputs: input Tensor, `2-D`, `[batch, input_size]`.
3009      state: A tuple of state Tensors, both `2-D`, with column sizes
3010        `c_state` and `m_state`.
3011
3012    Returns:
3013      A tuple containing:
3014
3015      - A `2-D`, `[batch x output_dim]` Tensor representing the output of the
3016        LSTM after reading `inputs` when the previous state was `state`.
3017        Here output_dim is:
3018           num_proj if num_proj was set,
3019           num_units otherwise.
3020      - Tensor(s) representing the new state of the LSTM after reading `inputs`
3021        when the previous state was `state`.  Same type and shape(s) as `state`.
3022
3023    Raises:
3024      ValueError: If input size cannot be inferred from inputs via
3025        static shape inference.
3026    """
3027    dtype = inputs.dtype
3028    num_units = self._num_units
3029    sigmoid = math_ops.sigmoid
3030    c, h = state
3031
3032    input_size = inputs.get_shape().with_rank(2).dims[1]
3033    if input_size.value is None:
3034      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
3035
3036    with vs.variable_scope(self._scope, initializer=self._initializer):
3037
3038      concat = self._linear(
3039          [inputs, h], 4 * num_units, norm=self._norm, bias=True)
3040
3041      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
3042      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
3043
3044      if self._use_peepholes:
3045        w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
3046        w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
3047        w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
3048
3049        new_c = (
3050            c * sigmoid(f + self._forget_bias + w_f_diag * c) +
3051            sigmoid(i + w_i_diag * c) * self._activation(j))
3052      else:
3053        new_c = (
3054            c * sigmoid(f + self._forget_bias) +
3055            sigmoid(i) * self._activation(j))
3056
3057      if self._cell_clip is not None:
3058        # pylint: disable=invalid-unary-operand-type
3059        new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
3060        # pylint: enable=invalid-unary-operand-type
3061      if self._use_peepholes:
3062        new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
3063      else:
3064        new_h = sigmoid(o) * self._activation(new_c)
3065
3066      if self._num_proj is not None:
3067        with vs.variable_scope("projection"):
3068          new_h = self._linear(
3069              new_h, self._num_proj, norm=self._norm, bias=False)
3070
3071        if self._proj_clip is not None:
3072          # pylint: disable=invalid-unary-operand-type
3073          new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
3074                                         self._proj_clip)
3075          # pylint: enable=invalid-unary-operand-type
3076
3077      new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
3078      return new_h, new_state
3079
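
# A minimal NumPy sketch (illustration only) of the weight normalization used
# by `WeightNormLSTMCell._normalize` above: each column of the kernel is
# rescaled to unit L2 norm and then multiplied by a learned per-column gain
# `g`. (`nn_impl.l2_normalize` additionally guards against a zero norm with a
# small epsilon.) The helper name is hypothetical.
def _weight_norm_numpy_sketch(v, g):
  """Returns g * v / ||v||, where the norm is taken over axis 0 (per column)."""
  import numpy as np
  col_norm = np.sqrt(np.sum(np.square(v), axis=0, keepdims=True))
  return g * v / col_norm
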
3080
3081class IndRNNCell(rnn_cell_impl.LayerRNNCell):
3082  """Independently Recurrent Neural Network (IndRNN) cell
3083    (cf. https://arxiv.org/abs/1803.04831).
3084
3085  Args:
3086    num_units: int, The number of units in the RNN cell.
3087    activation: Nonlinearity to use.  Default: `tanh`.
3088    reuse: (optional) Python boolean describing whether to reuse variables
3089     in an existing scope.  If not `True`, and the existing scope already has
3090     the given variables, an error is raised.
3091    name: String, the name of the layer. Layers with the same name will
3092      share weights, but to avoid mistakes we require reuse=True in such
3093      cases.
3094    dtype: Default dtype of the layer (default of `None` means use the type
3095      of the first input). Required when `build` is called before `call`.
3096  """
3097
3098  def __init__(self,
3099               num_units,
3100               activation=None,
3101               reuse=None,
3102               name=None,
3103               dtype=None):
3104    super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
3105
3106    # Inputs must be 2-dimensional.
3107    self.input_spec = input_spec.InputSpec(ndim=2)
3108
3109    self._num_units = num_units
3110    self._activation = activation or math_ops.tanh
3111
3112  @property
3113  def state_size(self):
3114    return self._num_units
3115
3116  @property
3117  def output_size(self):
3118    return self._num_units
3119
3120  def build(self, inputs_shape):
3121    if tensor_shape.dimension_value(inputs_shape[1]) is None:
3122      raise ValueError(
3123          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
3124
3125    input_depth = tensor_shape.dimension_value(inputs_shape[1])
3126    # pylint: disable=protected-access
3127    self._kernel_w = self.add_variable(
3128        "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3129        shape=[input_depth, self._num_units])
3130    self._kernel_u = self.add_variable(
3131        "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3132        shape=[1, self._num_units],
3133        initializer=init_ops.random_uniform_initializer(
3134            minval=-1, maxval=1, dtype=self.dtype))
3135    self._bias = self.add_variable(
3136        rnn_cell_impl._BIAS_VARIABLE_NAME,
3137        shape=[self._num_units],
3138        initializer=init_ops.zeros_initializer(dtype=self.dtype))
3139    # pylint: enable=protected-access
3140
3141    self.built = True
3142
3143  def call(self, inputs, state):
3144    """IndRNN: output = new_state = act(W * input + u * state + B)."""
3145
3146    gate_inputs = math_ops.matmul(inputs, self._kernel_w) + (
3147        state * self._kernel_u)
3148    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
3149    output = self._activation(gate_inputs)
3150    return output, output
3151
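
# Illustrative usage sketch (not part of the public API): unrolling an
# `IndRNNCell` over a random batch with `dynamic_rnn`. The helper name and
# the shapes below are arbitrary.
def _indrnn_usage_sketch(batch_size=2, time_steps=3, input_depth=4,
                         num_units=5):
  """Builds a tiny graph around `IndRNNCell` and returns (outputs, state)."""
  from tensorflow.python.ops import rnn  # Local import for the example.

  cell = IndRNNCell(num_units)
  inputs = random_ops.random_normal([batch_size, time_steps, input_depth])
  # outputs: [batch, time, num_units]; state: [batch, num_units].
  return rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
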
3152
3153class IndyGRUCell(rnn_cell_impl.LayerRNNCell):
3154  r"""Independently Gated Recurrent Unit cell.
3155
3156  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell,
3157  yet with the \\(U_r\\), \\(U_z\\), and \\(U\\) matrices in equations 5, 6, and
3158  8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal
3159  matrices, i.e. a Hadamard product with a single vector:
3160
3161    $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j +
3162      [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$
3163    $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j +
3164      [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$
3165    $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j +
3166      [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$
3167
3168  where \\(\circ\\) denotes the Hadamard operator. This means that each IndyGRU
3169  node sees only its own state, as opposed to seeing all states in the same
3170  layer.
3171
3172  Args:
3173    num_units: int, The number of units in the GRU cell.
3174    activation: Nonlinearity to use.  Default: `tanh`.
3175    reuse: (optional) Python boolean describing whether to reuse variables
3176     in an existing scope.  If not `True`, and the existing scope already has
3177     the given variables, an error is raised.
3178    kernel_initializer: (optional) The initializer to use for the weight
3179      matrices applied to the input.
3180    bias_initializer: (optional) The initializer to use for the bias.
3181    name: String, the name of the layer. Layers with the same name will
3182      share weights, but to avoid mistakes we require reuse=True in such
3183      cases.
3184    dtype: Default dtype of the layer (default of `None` means use the type
3185      of the first input). Required when `build` is called before `call`.
3186  """
3187
3188  def __init__(self,
3189               num_units,
3190               activation=None,
3191               reuse=None,
3192               kernel_initializer=None,
3193               bias_initializer=None,
3194               name=None,
3195               dtype=None):
3196    super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
3197
3198    # Inputs must be 2-dimensional.
3199    self.input_spec = input_spec.InputSpec(ndim=2)
3200
3201    self._num_units = num_units
3202    self._activation = activation or math_ops.tanh
3203    self._kernel_initializer = kernel_initializer
3204    self._bias_initializer = bias_initializer
3205
3206  @property
3207  def state_size(self):
3208    return self._num_units
3209
3210  @property
3211  def output_size(self):
3212    return self._num_units
3213
3214  def build(self, inputs_shape):
3215    if tensor_shape.dimension_value(inputs_shape[1]) is None:
3216      raise ValueError(
3217          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
3218
3219    input_depth = tensor_shape.dimension_value(inputs_shape[1])
3220    # pylint: disable=protected-access
3221    self._gate_kernel_w = self.add_variable(
3222        "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3223        shape=[input_depth, 2 * self._num_units],
3224        initializer=self._kernel_initializer)
3225    self._gate_kernel_u = self.add_variable(
3226        "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3227        shape=[1, 2 * self._num_units],
3228        initializer=init_ops.random_uniform_initializer(
3229            minval=-1, maxval=1, dtype=self.dtype))
3230    self._gate_bias = self.add_variable(
3231        "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
3232        shape=[2 * self._num_units],
3233        initializer=(self._bias_initializer
3234                     if self._bias_initializer is not None else
3235                     init_ops.constant_initializer(1.0, dtype=self.dtype)))
3236    self._candidate_kernel_w = self.add_variable(
3237        "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3238        shape=[input_depth, self._num_units],
3239        initializer=self._kernel_initializer)
3240    self._candidate_kernel_u = self.add_variable(
3241        "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3242        shape=[1, self._num_units],
3243        initializer=init_ops.random_uniform_initializer(
3244            minval=-1, maxval=1, dtype=self.dtype))
3245    self._candidate_bias = self.add_variable(
3246        "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
3247        shape=[self._num_units],
3248        initializer=(self._bias_initializer
3249                     if self._bias_initializer is not None else
3250                     init_ops.zeros_initializer(dtype=self.dtype)))
3251    # pylint: enable=protected-access
3252
3253    self.built = True
3254
3255  def call(self, inputs, state):
3256    """Recurrently independent Gated Recurrent Unit (GRU) with nunits cells."""
3257
3258    gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + (
3259        gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u)
3260    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
3261
3262    value = math_ops.sigmoid(gate_inputs)
3263    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
3264
3265    r_state = r * state
3266
3267    candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + (
3268        r_state * self._candidate_kernel_u)
3269    candidate = nn_ops.bias_add(candidate, self._candidate_bias)
3270
3271    c = self._activation(candidate)
3272    new_h = u * state + (1 - u) * c
3273    return new_h, new_h
3274
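
# A minimal NumPy sketch (illustration only) of one IndyGRU step, mirroring
# `IndyGRUCell.call` above: the recurrent connections `u_gate` and `u_cand`
# are single vectors applied with a Hadamard product instead of full
# matrices. All names below are hypothetical.
def _indygru_step_numpy_sketch(x, h_prev, w_gate, u_gate, b_gate,
                               w_cand, u_cand, b_cand):
  """w_gate: [input_depth, 2 * num_units]; u_gate: [1, 2 * num_units];
  w_cand: [input_depth, num_units]; u_cand: [1, num_units]."""
  import numpy as np

  def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

  gates = sigmoid(x.dot(w_gate) + np.tile(h_prev, (1, 2)) * u_gate + b_gate)
  r, u = np.split(gates, 2, axis=1)
  candidate = np.tanh(x.dot(w_cand) + (r * h_prev) * u_cand + b_cand)
  return u * h_prev + (1.0 - u) * candidate
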
3275
3276class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
3277  r"""Basic IndyLSTM recurrent network cell.
3278
3279  Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to
3280  BasicLSTMCell, yet with the \\(U_f\\), \\(U_i\\), \\(U_o\\) and \\(U_c\\)
3281  matrices in the regular LSTM equations replaced by diagonal matrices, i.e. a
3282  Hadamard product with a single vector:
3283
3284    $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$
3285    $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$
3286    $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$
3287    $$c_t = f_t \circ c_{t-1} +
3288            i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$
3289
3290  where \\(\circ\\) denotes the Hadamard operator. This means that each IndyLSTM
3291  node sees only its own state \\(h\\) and \\(c\\), as opposed to seeing all
3292  states in the same layer.
3293
3294  We add forget_bias (default: 1) to the biases of the forget gate in order to
3295  reduce the scale of forgetting in the beginning of the training.
3296
3297  It does not allow cell clipping or a projection layer, and it does not
3298  use peephole connections: it is the basic baseline.
3299  """
3300
3301  def __init__(self,
3302               num_units,
3303               forget_bias=1.0,
3304               activation=None,
3305               reuse=None,
3306               kernel_initializer=None,
3307               bias_initializer=None,
3308               name=None,
3309               dtype=None):
3310    """Initialize the IndyLSTM cell.
3311
3312    Args:
3313      num_units: int, The number of units in the LSTM cell.
3314      forget_bias: float, The bias added to forget gates (see above).
3315        Must set to `0.0` manually when restoring from CudnnLSTM-trained
3316        checkpoints.
3317      activation: Activation function of the inner states.  Default: `tanh`.
3318      reuse: (optional) Python boolean describing whether to reuse variables
3319        in an existing scope.  If not `True`, and the existing scope already has
3320        the given variables, an error is raised.
3321      kernel_initializer: (optional) The initializer to use for the weight
3322        matrix applied to the inputs.
3323      bias_initializer: (optional) The initializer to use for the bias.
3324      name: String, the name of the layer. Layers with the same name will
3325        share weights, but to avoid mistakes we require reuse=True in such
3326        cases.
3327      dtype: Default dtype of the layer (default of `None` means use the type
3328        of the first input). Required when `build` is called before `call`.
3329    """
3330    super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
3331
3332    # Inputs must be 2-dimensional.
3333    self.input_spec = input_spec.InputSpec(ndim=2)
3334
3335    self._num_units = num_units
3336    self._forget_bias = forget_bias
3337    self._activation = activation or math_ops.tanh
3338    self._kernel_initializer = kernel_initializer
3339    self._bias_initializer = bias_initializer
3340
3341  @property
3342  def state_size(self):
3343    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
3344
3345  @property
3346  def output_size(self):
3347    return self._num_units
3348
3349  def build(self, inputs_shape):
3350    if tensor_shape.dimension_value(inputs_shape[1]) is None:
3351      raise ValueError(
3352          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
3353
3354    input_depth = tensor_shape.dimension_value(inputs_shape[1])
3355    # pylint: disable=protected-access
3356    self._kernel_w = self.add_variable(
3357        "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3358        shape=[input_depth, 4 * self._num_units],
3359        initializer=self._kernel_initializer)
3360    self._kernel_u = self.add_variable(
3361        "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3362        shape=[1, 4 * self._num_units],
3363        initializer=init_ops.random_uniform_initializer(
3364            minval=-1, maxval=1, dtype=self.dtype))
3365    self._bias = self.add_variable(
3366        rnn_cell_impl._BIAS_VARIABLE_NAME,
3367        shape=[4 * self._num_units],
3368        initializer=(self._bias_initializer
3369                     if self._bias_initializer is not None else
3370                     init_ops.zeros_initializer(dtype=self.dtype)))
3371    # pylint: enable=protected-access
3372
3373    self.built = True
3374
3375  def call(self, inputs, state):
3376    """Independent Long short-term memory cell (IndyLSTM).
3377
3378    Args:
3379      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
3380      state: An `LSTMStateTuple` of state tensors, each shaped
3381        `[batch_size, num_units]`.
3382
3383    Returns:
3384      A pair containing the new hidden state, and the new state (a
3385        `LSTMStateTuple`).
3386    """
3387    sigmoid = math_ops.sigmoid
3388    one = constant_op.constant(1, dtype=dtypes.int32)
3389    c, h = state
3390
3391    gate_inputs = math_ops.matmul(inputs, self._kernel_w)
3392    gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u
3393    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
3394
3395    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
3396    i, j, f, o = array_ops.split(
3397        value=gate_inputs, num_or_size_splits=4, axis=one)
3398
3399    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
3400    # Note that using `add` and `multiply` instead of `+` and `*` gives a
3401    # performance improvement, so we use them at the cost of some readability.
3402    add = math_ops.add
3403    multiply = math_ops.multiply
3404    new_c = add(
3405        multiply(c, sigmoid(add(f, forget_bias_tensor))),
3406        multiply(sigmoid(i), self._activation(j)))
3407    new_h = multiply(self._activation(new_c), sigmoid(o))
3408
3409    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
3410    return new_h, new_state
3411
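
# Illustrative usage sketch (not part of the public API): driving an
# `IndyLSTMCell` with `dynamic_rnn`; the final state comes back as an
# `LSTMStateTuple(c, h)`. The helper name and shapes are arbitrary.
def _indylstm_usage_sketch(batch_size=2, time_steps=3, input_depth=4,
                           num_units=5):
  from tensorflow.python.ops import rnn  # Local import for the example.

  cell = IndyLSTMCell(num_units)
  inputs = random_ops.random_normal([batch_size, time_steps, input_depth])
  outputs, state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
  # outputs: [batch, time, num_units]; state.c and state.h: [batch, num_units].
  return outputs, state
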
3412
3413NTMControllerState = collections.namedtuple(
3414    "NTMControllerState",
3415    ("controller_state", "read_vector_list", "w_list", "M", "time"))
3416
3417
3418class NTMCell(rnn_cell_impl.LayerRNNCell):
3419  """Neural Turing Machine Cell with RNN controller.
3420
3421    Implementation based on:
3422    https://arxiv.org/abs/1807.08518
3423    Mark Collier, Joeran Beel
3424
3425    which is in turn based on the source code of:
3426    https://github.com/snowkylin/ntm
3427
3428    and of course the original NTM paper:
3429    Neural Turing Machines
3430    https://arxiv.org/abs/1410.5401
3431    A Graves, G Wayne, I Danihelka
3432  """
3433
3434  def __init__(self,
3435               controller,
3436               memory_size,
3437               memory_vector_dim,
3438               read_head_num,
3439               write_head_num,
3440               shift_range=1,
3441               output_dim=None,
3442               clip_value=20,
3443               dtype=dtypes.float32,
3444               name=None):
3445    """Initialize the NTM Cell.
3446
3447      Args:
3448        controller: an RNNCell, the RNN controller.
3449        memory_size: int, The number of memory locations in the NTM memory
3450          matrix.
3451        memory_vector_dim: int, The dimensionality of each location in the NTM
3452          memory matrix.
3453        read_head_num: int, The number of read heads from the controller into
3454          memory.
3455        write_head_num: int, The number of write heads from the controller into
3456          memory.
3457        shift_range: int, The maximum number of positions the previous address
3458          can be shifted to the left or right in a single step.
3459        output_dim: int, The dimensionality of the linear projection applied to
3460          the NTM controller outputs. If None, no linear projection is
3461          applied.
3462        clip_value: float, The maximum absolute value to which the controller
3463          parameters are clipped.
3464        dtype: Default dtype of the layer; defaults to `dtypes.float32`.
3465          Required when `build` is called before `call`.
3466        name: String, the name of the layer. Layers with the same name will
3467          share weights, but to avoid mistakes we require reuse=True in such
3468          cases.
3469    """
3470    super(NTMCell, self).__init__(dtype=dtype, name=name)
3471
3472    rnn_cell_impl.assert_like_rnncell("NTM RNN controller cell", controller)
3473
3474    self.controller = controller
3475    self.memory_size = memory_size
3476    self.memory_vector_dim = memory_vector_dim
3477    self.read_head_num = read_head_num
3478    self.write_head_num = write_head_num
3479    self.clip_value = clip_value
3480
3481    self.output_dim = output_dim
3482    self.shift_range = shift_range
3483
3484    self.num_parameters_per_head = (
3485        self.memory_vector_dim + 2 * self.shift_range + 4)
3486    self.num_heads = self.read_head_num + self.write_head_num
3487    self.total_parameter_num = (
3488        self.num_parameters_per_head * self.num_heads +
3489        self.memory_vector_dim * 2 * self.write_head_num)
3490
3491  @property
3492  def state_size(self):
3493    return NTMControllerState(
3494        controller_state=self.controller.state_size,
3495        read_vector_list=[
3496            self.memory_vector_dim for _ in range(self.read_head_num)
3497        ],
3498        w_list=[
3499            self.memory_size
3500            for _ in range(self.read_head_num + self.write_head_num)
3501        ],
3502        M=tensor_shape.TensorShape([self.memory_size * self.memory_vector_dim]),
3503        time=tensor_shape.TensorShape([]))
3504
3505  @property
3506  def output_size(self):
3507    return self.output_dim
3508
3509  def build(self, inputs_shape):
3510    if self.output_dim is None:
3511      if inputs_shape[1].value is None:
3512        raise ValueError(
3513            "Expected inputs.shape[-1] to be known, saw shape: %s" %
3514            inputs_shape)
3515      else:
3516        self.output_dim = inputs_shape[1].value
3517
3518    def _create_linear_initializer(input_size, dtype=dtypes.float32):
3519      stddev = 1.0 / math.sqrt(input_size)
3520      return init_ops.truncated_normal_initializer(stddev=stddev, dtype=dtype)
3521
3522    self._params_kernel = self.add_variable(
3523        "parameters_kernel",
3524        shape=[self.controller.output_size, self.total_parameter_num],
3525        initializer=_create_linear_initializer(self.controller.output_size))
3526
3527    self._params_bias = self.add_variable(
3528        "parameters_bias",
3529        shape=[self.total_parameter_num],
3530        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))
3531
3532    self._output_kernel = self.add_variable(
3533        "output_kernel",
3534        shape=[
3535            self.controller.output_size +
3536            self.memory_vector_dim * self.read_head_num, self.output_dim
3537        ],
3538        initializer=_create_linear_initializer(self.controller.output_size +
3539                                               self.memory_vector_dim *
3540                                               self.read_head_num))
3541
3542    self._output_bias = self.add_variable(
3543        "output_bias",
3544        shape=[self.output_dim],
3545        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))
3546
3547    self._init_read_vectors = [
3548        self.add_variable(
3549            "initial_read_vector_%d" % i,
3550            shape=[1, self.memory_vector_dim],
3551            initializer=initializers.glorot_uniform())
3552        for i in range(self.read_head_num)
3553    ]
3554
3555    self._init_address_weights = [
3556        self.add_variable(
3557            "initial_address_weights_%d" % i,
3558            shape=[1, self.memory_size],
3559            initializer=initializers.glorot_uniform())
3560        for i in range(self.read_head_num + self.write_head_num)
3561    ]
3562
3563    self._M = self.add_variable(
3564        "memory",
3565        shape=[self.memory_size, self.memory_vector_dim],
3566        initializer=init_ops.constant_initializer(1e-6, dtype=self.dtype))
3567
3568    self.built = True
3569
3570  def call(self, x, prev_state):
3571    # Addressing Mechanisms (Sec 3.3)
3572
3573    def _prev_read_vector_list_initial_value():
3574      return [
3575          self._expand(
3576              math_ops.tanh(
3577                  array_ops.squeeze(
3578                      math_ops.matmul(
3579                          array_ops.ones([1, 1]), self._init_read_vectors[i]))),
3580              dim=0,
3581              N=x.shape[0].value or array_ops.shape(x)[0])
3582          for i in range(self.read_head_num)
3583      ]
3584
3585    prev_read_vector_list = control_flow_ops.cond(
3586        math_ops.equal(prev_state.time, 0),
3587        _prev_read_vector_list_initial_value,
3588        lambda: prev_state.read_vector_list)
3589    if self.read_head_num == 1:
3590      prev_read_vector_list = [prev_read_vector_list]
3591
3592    controller_input = array_ops.concat([x] + prev_read_vector_list, axis=1)
3593    controller_output, controller_state = self.controller(
3594        controller_input, prev_state.controller_state)
3595
3596    parameters = math_ops.matmul(controller_output, self._params_kernel)
3597    parameters = nn_ops.bias_add(parameters, self._params_bias)
3598    parameters = clip_ops.clip_by_value(parameters, -self.clip_value,
3599                                        self.clip_value)
3600    head_parameter_list = array_ops.split(
3601        parameters[:, :self.num_parameters_per_head * self.num_heads],
3602        self.num_heads,
3603        axis=1)
3604    erase_add_list = array_ops.split(
3605        parameters[:, self.num_parameters_per_head * self.num_heads:],
3606        2 * self.write_head_num,
3607        axis=1)
3608
3609    def _prev_w_list_initial_value():
3610      return [
3611          self._expand(
3612              nn_ops.softmax(
3613                  array_ops.squeeze(
3614                      math_ops.matmul(
3615                          array_ops.ones([1, 1]),
3616                          self._init_address_weights[i]))),
3617              dim=0,
3618              N=x.shape[0].value or array_ops.shape(x)[0])
3619          for i in range(self.read_head_num + self.write_head_num)
3620      ]
3621
3622    prev_w_list = control_flow_ops.cond(
3623        math_ops.equal(prev_state.time, 0),
3624        _prev_w_list_initial_value, lambda: prev_state.w_list)
3625    if (self.read_head_num + self.write_head_num) == 1:
3626      prev_w_list = [prev_w_list]
3627
3628    prev_M = control_flow_ops.cond(
3629        math_ops.equal(prev_state.time, 0), lambda: self._expand(
3630            self._M, dim=0, N=x.shape[0].value or array_ops.shape(x)[0]),
3631        lambda: prev_state.M)
3632
3633    w_list = []
3634    for i, head_parameter in enumerate(head_parameter_list):
3635      k = math_ops.tanh(head_parameter[:, 0:self.memory_vector_dim])
3636      beta = nn_ops.softplus(head_parameter[:, self.memory_vector_dim])
3637      g = math_ops.sigmoid(head_parameter[:, self.memory_vector_dim + 1])
3638      s = nn_ops.softmax(head_parameter[:, self.memory_vector_dim +
3639                                        2:(self.memory_vector_dim + 2 +
3640                                           (self.shift_range * 2 + 1))])
3641      gamma = nn_ops.softplus(head_parameter[:, -1]) + 1
3642      w = self._addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i])
3643      w_list.append(w)
3644
3645    # Reading (Sec 3.1)
3646
3647    read_w_list = w_list[:self.read_head_num]
3648    read_vector_list = []
3649    for i in range(self.read_head_num):
3650      read_vector = math_ops.reduce_sum(
3651          array_ops.expand_dims(read_w_list[i], axis=2) * prev_M, axis=1)
3652      read_vector_list.append(read_vector)
3653
3654    # Writing (Sec 3.2)
3655
3656    write_w_list = w_list[self.read_head_num:]
3657    M = prev_M
3658    for i in range(self.write_head_num):
3659      w = array_ops.expand_dims(write_w_list[i], axis=2)
3660      erase_vector = array_ops.expand_dims(
3661          math_ops.sigmoid(erase_add_list[i * 2]), axis=1)
3662      add_vector = array_ops.expand_dims(
3663          math_ops.tanh(erase_add_list[i * 2 + 1]), axis=1)
3664      erase_M = array_ops.ones_like(M) - math_ops.matmul(w, erase_vector)
3665      M = M * erase_M + math_ops.matmul(w, add_vector)
3666
3667    output = math_ops.matmul(
3668        array_ops.concat([controller_output] + read_vector_list, axis=1),
3669        self._output_kernel)
3670    output = nn_ops.bias_add(output, self._output_bias)
3671    output = clip_ops.clip_by_value(output, -self.clip_value, self.clip_value)
3672
3673    return output, NTMControllerState(
3674        controller_state=controller_state,
3675        read_vector_list=read_vector_list,
3676        w_list=w_list,
3677        M=M,
3678        time=prev_state.time + 1)
3679
3680  def _expand(self, x, dim, N):
3681    return array_ops.concat([array_ops.expand_dims(x, dim) for _ in range(N)],
3682                            axis=dim)
3683
3684  def _addressing(self, k, beta, g, s, gamma, prev_M, prev_w):
3685    # Sec 3.3.1 Focusing by Content
3686
3687    k = array_ops.expand_dims(k, axis=2)
3688    inner_product = math_ops.matmul(prev_M, k)
3689    k_norm = math_ops.sqrt(
3690        math_ops.reduce_sum(math_ops.square(k), axis=1, keepdims=True))
3691    M_norm = math_ops.sqrt(
3692        math_ops.reduce_sum(math_ops.square(prev_M), axis=2, keepdims=True))
3693    norm_product = M_norm * k_norm
3694
3695    # eq (6)
3696    K = array_ops.squeeze(inner_product / (norm_product + 1e-8))
3697
3698    K_amplified = math_ops.exp(array_ops.expand_dims(beta, axis=1) * K)
3699
3700    # eq (5)
3701    w_c = K_amplified / math_ops.reduce_sum(K_amplified, axis=1, keepdims=True)
3702
3703    # Sec 3.3.2 Focusing by Location
3704
3705    g = array_ops.expand_dims(g, axis=1)
3706
3707    # eq (7)
3708    w_g = g * w_c + (1 - g) * prev_w
3709
3710    s = array_ops.concat([
3711        s[:, :self.shift_range + 1],
3712        array_ops.zeros([
3713            s.shape[0].value or array_ops.shape(s)[0], self.memory_size -
3714            (self.shift_range * 2 + 1)
3715        ]), s[:, -self.shift_range:]
3716    ],
3717                         axis=1)
3718    t = array_ops.concat(
3719        [array_ops.reverse(s, axis=[1]),
3720         array_ops.reverse(s, axis=[1])],
3721        axis=1)
3722    s_matrix = array_ops.stack([
3723        t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1]
3724        for i in range(self.memory_size)
3725    ],
3726                               axis=1)
3727
3728    # eq (8)
3729    w_ = math_ops.reduce_sum(
3730        array_ops.expand_dims(w_g, axis=1) * s_matrix, axis=2)
3731    w_sharpen = math_ops.pow(w_, array_ops.expand_dims(gamma, axis=1))
3732
3733    # eq (9)
3734    w = w_sharpen / math_ops.reduce_sum(w_sharpen, axis=1, keepdims=True)
3735
3736    return w
3737
3738  def zero_state(self, batch_size, dtype):
3739    read_vector_list = [
3740        array_ops.zeros([batch_size, self.memory_vector_dim])
3741        for _ in range(self.read_head_num)
3742    ]
3743
3744    w_list = [
3745        array_ops.zeros([batch_size, self.memory_size])
3746        for _ in range(self.read_head_num + self.write_head_num)
3747    ]
3748
3749    controller_init_state = self.controller.zero_state(batch_size, dtype)
3750
3751    M = array_ops.zeros([batch_size, self.memory_size, self.memory_vector_dim])
3752
3753    return NTMControllerState(
3754        controller_state=controller_init_state,
3755        read_vector_list=read_vector_list,
3756        w_list=w_list,
3757        M=M,
3758        time=0)
3759
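
# Illustrative sketch (not part of the public API) of wiring an `NTMCell`
# around an LSTM controller and running a single step from the zero state.
# The controller size, memory geometry and `output_dim` below are arbitrary.
def _ntm_single_step_sketch(batch_size=2, input_depth=6):
  controller = rnn_cell_impl.BasicLSTMCell(32)
  cell = NTMCell(
      controller,
      memory_size=16,
      memory_vector_dim=8,
      read_head_num=1,
      write_head_num=1,
      output_dim=10)
  x = random_ops.random_normal([batch_size, input_depth])
  state = cell.zero_state(batch_size, dtypes.float32)
  # output: [batch, output_dim]; new_state is an `NTMControllerState`.
  output, new_state = cell(x, state)
  return output, new_state
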
3760
3761class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
3762  """MinimalRNN cell.
3763
3764  The implementation is based on:
3765
3766    https://arxiv.org/pdf/1806.05394v2.pdf
3767
3768  Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz.
3769  "Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal
3770   Propagation in Recurrent Neural Networks." ICML, 2018.
3771
3772  A MinimalRNN cell first projects the input to the hidden space. The new
3773  hidden state is then calculated as a weighted sum of the projected input and
3774  the previous hidden state, using a single update gate.
3775  """
3776
3777  def __init__(self,
3778               units,
3779               activation="tanh",
3780               kernel_initializer="glorot_uniform",
3781               bias_initializer="ones",
3782               name=None,
3783               dtype=None,
3784               **kwargs):
3785    """Initialize the parameters for a MinimalRNN cell.
3786
3787    Args:
3788      units: int, The number of units in the MinimalRNN cell.
3789      activation: Nonlinearity to use in the feedforward network. Default:
3790        `tanh`.
3791      kernel_initializer: The initializer to use for the weight in the update
3792        gate and feedforward network. Default: `glorot_uniform`.
3793      bias_initializer: The initializer to use for the bias in the update
3794        gate. Default: `ones`.
3795      name: String, the name of the cell.
3796      dtype: Default dtype of the cell.
3797      **kwargs: Dict, keyword named properties for common cell attributes.
3798    """
3799    super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs)
3800
3801    # Inputs must be 2-dimensional.
3802    self.input_spec = input_spec.InputSpec(ndim=2)
3803
3804    self.units = units
3805    self.activation = activations.get(activation)
3806    self.kernel_initializer = initializers.get(kernel_initializer)
3807    self.bias_initializer = initializers.get(bias_initializer)
3808
3809  @property
3810  def state_size(self):
3811    return self.units
3812
3813  @property
3814  def output_size(self):
3815    return self.units
3816
3817  def build(self, inputs_shape):
3818    if inputs_shape[-1] is None:
3819      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
3820                       % str(inputs_shape))
3821
3822    input_size = inputs_shape[-1]
3823    # pylint: disable=protected-access
3824    # self.kernel contains W_x, W, V
3825    self.kernel = self.add_weight(
3826        name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3827        shape=[input_size + 2 * self.units, self.units],
3828        initializer=self.kernel_initializer)
3829    self.bias = self.add_weight(
3830        name=rnn_cell_impl._BIAS_VARIABLE_NAME,
3831        shape=[self.units],
3832        initializer=self.bias_initializer)
3833    # pylint: enable=protected-access
3834
3835    self.built = True
3836
3837  def call(self, inputs, state):
3838    """Run one step of MinimalRNN.
3839
3840    Args:
3841      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
3842      state: state Tensor, must be 2-D, `[batch, state_size]`.
3843
3844    Returns:
3845      A tuple containing:
3846
3847      - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
3848      - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
3849
3850    Raises:
3851      ValueError: If input size cannot be inferred from inputs via
3852        static shape inference.
3853    """
3854    input_size = inputs.get_shape()[1]
3855    if tensor_shape.dimension_value(input_size) is None:
3856      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
3857
3858    feedforward_weight, gate_weight = array_ops.split(
3859        value=self.kernel,
3860        num_or_size_splits=[tensor_shape.dimension_value(input_size),
3861                            2 * self.units],
3862        axis=0)
3863
3864    feedforward = math_ops.matmul(inputs, feedforward_weight)
3865    feedforward = self.activation(feedforward)
3866
3867    gate_inputs = math_ops.matmul(
3868        array_ops.concat([feedforward, state], 1), gate_weight)
3869    gate_inputs = nn_ops.bias_add(gate_inputs, self.bias)
3870    u = math_ops.sigmoid(gate_inputs)
3871
3872    new_h = u * state + (1 - u) * feedforward
3873    return new_h, new_h
3874
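
# A minimal NumPy sketch (illustration only) of one MinimalRNN step, mirroring
# `MinimalRNNCell.call` above. `w_x`, `w_gate` and `b` are hypothetical names
# for the two pieces of `self.kernel` and for `self.bias`.
def _minimal_rnn_step_numpy_sketch(x, h_prev, w_x, w_gate, b):
  """w_x: [input_size, units]; w_gate: [2 * units, units]; b: [units]."""
  import numpy as np

  z = np.tanh(x.dot(w_x))  # project the input into the hidden space
  pre_gate = np.concatenate([z, h_prev], axis=1).dot(w_gate) + b
  u = 1.0 / (1.0 + np.exp(-pre_gate))  # single update gate
  return u * h_prev + (1.0 - u) * z
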
3875
3876class CFNCell(rnn_cell_impl.LayerRNNCell):
3877  """Chaos Free Network cell.
3878
3879  The implementation is based on:
3880
3881    https://openreview.net/pdf?id=S1dIzvclg
3882
3883  Thomas Laurent, James von Brecht.
3884  "A recurrent neural network without chaos." ICLR, 2017.
3885
3886  A CFN cell first projects the input to the hidden space. The hidden state
3887  goes through a contractive mapping. The new hidden state is then calculated
3888  as a linear combination of the projected input and the contracted previous
3889  hidden state, using decoupled input and forget gates.
3890  """
3891
3892  def __init__(self,
3893               units,
3894               activation="tanh",
3895               kernel_initializer="glorot_uniform",
3896               bias_initializer="ones",
3897               name=None,
3898               dtype=None,
3899               **kwargs):
3900    """Initialize the parameters for a CFN cell.
3901
3902    Args:
3903      units: int, The number of units in the CFN cell.
3904      activation: Nonlinearity to use. Default: `tanh`.
3905      kernel_initializer: Initializer for the `kernel` weights
3906        matrix. Default: `glorot_uniform`.
3907      bias_initializer: The initializer to use for the bias in the
3908        gates. Default: `ones`.
3909      name: String, the name of the cell.
3910      dtype: Default dtype of the cell.
3911      **kwargs: Dict, keyword named properties for common cell attributes.
3912    """
3913    super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs)
3914
3915    # Inputs must be 2-dimensional.
3916    self.input_spec = input_spec.InputSpec(ndim=2)
3917
3918    self.units = units
3919    self.activation = activations.get(activation)
3920    self.kernel_initializer = initializers.get(kernel_initializer)
3921    self.bias_initializer = initializers.get(bias_initializer)
3922
3923  @property
3924  def state_size(self):
3925    return self.units
3926
3927  @property
3928  def output_size(self):
3929    return self.units
3930
3931  def build(self, inputs_shape):
3932    if inputs_shape[-1] is None:
3933      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
3934                       % str(inputs_shape))
3935
3936    input_size = inputs_shape[-1]
3937    # pylint: disable=protected-access
3938    # `self.kernel` contains V_{\theta}, V_{\eta}, W.
3939    # `self.recurrent_kernel` contains U_{\theta}, U_{\eta}.
3940    # `self.bias` contains b_{\theta}, b_{\eta}.
3941    self.kernel = self.add_weight(
3942        shape=[input_size, 3 * self.units],
3943        name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3944        initializer=self.kernel_initializer)
3945    self.recurrent_kernel = self.add_weight(
3946        shape=[self.units, 2 * self.units],
3947        name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
3948        initializer=self.kernel_initializer)
3949    self.bias = self.add_weight(
3950        shape=[2 * self.units],
3951        name=rnn_cell_impl._BIAS_VARIABLE_NAME,
3952        initializer=self.bias_initializer)
3953    # pylint: enable=protected-access
3954
3955    self.built = True
3956
3957  def call(self, inputs, state):
3958    """Run one step of CFN.
3959
3960    Args:
3961      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
3962      state: state Tensor, must be 2-D, `[batch, state_size]`.
3963
3964    Returns:
3965      A tuple containing:
3966
3967      - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
3968      - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
3969
3970    Raises:
3971      ValueError: If input size cannot be inferred from inputs via
3972        static shape inference.
3973    """
3974    input_size = inputs.get_shape()[-1]
3975    if tensor_shape.dimension_value(input_size) is None:
3976      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
3977
3978    # The variable names u, v, w, b are consistent with the notations in the
3979    # original paper.
3980    v, w = array_ops.split(
3981        value=self.kernel,
3982        num_or_size_splits=[2 * self.units, self.units],
3983        axis=1)
3984    u = self.recurrent_kernel
3985    b = self.bias
3986
3987    gates = math_ops.matmul(state, u) + math_ops.matmul(inputs, v)
3988    gates = nn_ops.bias_add(gates, b)
3989    gates = math_ops.sigmoid(gates)
3990    theta, eta = array_ops.split(value=gates,
3991                                 num_or_size_splits=2,
3992                                 axis=1)
3993
3994    proj_input = math_ops.matmul(inputs, w)
3995
3996    # The input gate is (1 - eta), which is different from the original paper.
3997    # This is for the purpose of initialization. With the default
3998    # bias_initializer `ones`, the input gate is initialized to a small number.
3999    new_h = theta * self.activation(state) + (1 - eta) * self.activation(
4000        proj_input)
4001
4002    return new_h, new_h
4003
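
# A minimal NumPy sketch (illustration only) of one CFN step, mirroring
# `CFNCell.call` above with the default `tanh` activation. `v`, `w`, `u` and
# `b` are hypothetical names for the pieces of `self.kernel`,
# `self.recurrent_kernel` and `self.bias`.
def _cfn_step_numpy_sketch(x, h_prev, v, w, u, b):
  """v: [input_size, 2 * units]; w: [input_size, units];
  u: [units, 2 * units]; b: [2 * units]."""
  import numpy as np

  def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

  gates = sigmoid(h_prev.dot(u) + x.dot(v) + b)
  theta, eta = np.split(gates, 2, axis=1)
  # Contract the previous state and gate in the projected input.
  return theta * np.tanh(h_prev) + (1.0 - eta) * np.tanh(x.dot(w))
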