# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""LSTM Block Cell ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.contrib.rnn.ops import gen_lstm_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.platform import resource_loader

_lstm_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_lstm_ops.so"))

LayerRNNCell = rnn_cell_impl.LayerRNNCell  # pylint: disable=invalid-name


# pylint: disable=invalid-name
def _lstm_block_cell(x,
                     cs_prev,
                     h_prev,
                     w,
                     b,
                     wci=None,
                     wcf=None,
                     wco=None,
                     forget_bias=None,
                     cell_clip=None,
                     use_peephole=None,
                     name=None):
  r"""Computes the LSTM cell forward propagation for 1 time step.

  This implementation uses 1 weight matrix and 1 bias vector, and there's an
  optional peephole connection.

  This kernel op implements the following mathematical equations:

  ```python
  xh = [x, h_prev]
  [i, ci, f, o] = xh * w + b
  f = f + forget_bias

  if not use_peephole:
    wci = wcf = wco = 0

  i = sigmoid(cs_prev * wci + i)
  f = sigmoid(cs_prev * wcf + f)
  ci = tanh(ci)

  cs = ci .* i + cs_prev .* f
  cs = clip(cs, cell_clip)

  o = sigmoid(cs * wco + o)
  co = tanh(cs)
  h = co .* o
  ```

  Args:
    x: A `Tensor`. Must be one of the following types: `float32`.
      The input to the LSTM cell, shape (batch_size, num_inputs).
    cs_prev: A `Tensor`. Must have the same type as `x`.
      Value of the cell state at the previous time step.
    h_prev: A `Tensor`. Must have the same type as `x`.
      Output of the cell at the previous time step.
    w: A `Tensor`. Must have the same type as `x`. The weight matrix.
    b: A `Tensor`. Must have the same type as `x`. The bias vector.
    wci: A `Tensor`. Must have the same type as `x`.
      The weight matrix for the input gate peephole connection.
    wcf: A `Tensor`. Must have the same type as `x`.
      The weight matrix for the forget gate peephole connection.
    wco: A `Tensor`. Must have the same type as `x`.
      The weight matrix for the output gate peephole connection.
    forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
      Value to clip the 'cs' value to. Disable by setting to a negative value.
    use_peephole: An optional `bool`. Defaults to `False`.
      Whether to use peephole weights.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
    i: A `Tensor`. Has the same type as `x`. The input gate.
    cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
    f: A `Tensor`. Has the same type as `x`. The forget gate.
    o: A `Tensor`. Has the same type as `x`. The output gate.
    ci: A `Tensor`. Has the same type as `x`. The cell input.
    co: A `Tensor`. Has the same type as `x`. The cell state after the tanh.
    h: A `Tensor`. Has the same type as `x`. The output h vector.

  Raises:
    ValueError: If cell_size is None.
  """
  if wci is None:
    cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
    if cell_size is None:
      raise ValueError("cell_size from `cs_prev` should not be None.")
    wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
    wcf = wci
    wco = wci

  # pylint: disable=protected-access
  return gen_lstm_ops.lstm_block_cell(
      x=x,
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      forget_bias=forget_bias,
      cell_clip=cell_clip if cell_clip is not None else -1,
      use_peephole=use_peephole,
      name=name)
  # pylint: enable=protected-access


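# The following reference implementation is an editorial sketch, not part of
# the original module: it restates the `_lstm_block_cell` equations above in
# plain NumPy so the gate packing ([i, ci, f, o]) and the peephole terms can
# be checked by eye. The function name and its arguments are hypothetical.
def _numpy_lstm_block_cell_reference(x, cs_prev, h_prev, w, b,
                                     wci=0.0, wcf=0.0, wco=0.0,
                                     forget_bias=1.0, cell_clip=-1.0):
  """NumPy sketch of one LSTM block cell step (illustrative only)."""
  import numpy as np  # Local import; numpy is not otherwise used here.

  def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

  # xh = [x, h_prev]; a single matmul produces all four gate pre-activations.
  xh = np.concatenate([x, h_prev], axis=1)
  icfo = np.dot(xh, w) + b
  i, ci, f, o = np.split(icfo, 4, axis=1)

  i = sigmoid(cs_prev * wci + i)
  f = sigmoid(cs_prev * wcf + f + forget_bias)
  ci = np.tanh(ci)

  cs = ci * i + cs_prev * f
  if cell_clip >= 0:
    # Assumed clipping convention: negative cell_clip disables clipping.
    cs = np.clip(cs, -cell_clip, cell_clip)

  o = sigmoid(cs * wco + o)
  co = np.tanh(cs)
  h = co * o
  return i, cs, f, o, ci, co, h

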
def _block_lstm(seq_len_max,
                x,
                w,
                b,
                cs_prev=None,
                h_prev=None,
                wci=None,
                wcf=None,
                wco=None,
                forget_bias=None,
                cell_clip=None,
                use_peephole=None,
                name=None):
  r"""Computes the LSTM cell forward propagation for all time steps.

  This is the whole-sequence counterpart of `_lstm_block_cell`.

  Args:
    seq_len_max: A `Tensor` of type `int64`.
    x: A list of at least 1 `Tensor` objects of the same type.
    w: A `Tensor`. Must have the same type as `x`.
    b: A `Tensor`. Must have the same type as `x`.
    cs_prev: A `Tensor`. Must have the same type as `x`.
    h_prev: A `Tensor`. Must have the same type as `x`.
    wci: A `Tensor`. Must have the same type as `x`.
    wcf: A `Tensor`. Must have the same type as `x`.
    wco: A `Tensor`. Must have the same type as `x`.
    forget_bias: An optional `float`. Defaults to `1`.
    cell_clip: An optional `float`. Defaults to `-1` (no clipping).
    use_peephole: An optional `bool`. Defaults to `False`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h), where each element
    is a list with the same number of `Tensor` objects as `x`, of the same
    type as `x`.

  Raises:
    ValueError: If `b` does not have a valid shape.
  """
  dtype = x[0].dtype
  batch_size = x[0].get_shape().with_rank(2).dims[0].value
  cell_size4 = b.get_shape().with_rank(1).dims[0].value
  if cell_size4 is None:
    raise ValueError("`b` shape must not be None.")
  cell_size = cell_size4 // 4
  zero_state = None
  if cs_prev is None or h_prev is None:
    zero_state = array_ops.constant(
        0, dtype=dtype, shape=[batch_size, cell_size])
  if cs_prev is None:
    cs_prev = zero_state
  if h_prev is None:
    h_prev = zero_state
  if wci is None:
    wci = array_ops.constant(0, dtype=dtype, shape=[cell_size])
    wcf = wci
    wco = wci

  # pylint: disable=protected-access
  i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm(
      seq_len_max=seq_len_max,
      x=array_ops.stack(x),
      cs_prev=cs_prev,
      h_prev=h_prev,
      w=w,
      wci=wci,
      wcf=wcf,
      wco=wco,
      b=b,
      forget_bias=forget_bias,
      cell_clip=cell_clip if cell_clip is not None else -1,
      name=name,
      use_peephole=use_peephole)

  return array_ops.unstack(i), array_ops.unstack(cs), array_ops.unstack(
      f), array_ops.unstack(o), array_ops.unstack(ci), array_ops.unstack(
          co), array_ops.unstack(h)
  # pylint: enable=protected-access
  # pylint: enable=invalid-name


_lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"]


@ops.RegisterGradient("LSTMBlockCell")
def _LSTMBlockCellGrad(op, *grad):
  """Gradient for LSTMBlockCell."""
  (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
  (i, cs, f, o, ci, co, _) = op.outputs
  (_, cs_grad, _, _, _, _, h_grad) = grad

  batch_size = x.get_shape().with_rank(2).dims[0].value
  if batch_size is None:
    batch_size = -1
  input_size = x.get_shape().with_rank(2).dims[1].value
  if input_size is None:
    raise ValueError("input_size from `x` should not be None.")
  cell_size = cs_prev.get_shape().with_rank(2).dims[1].value
  if cell_size is None:
    raise ValueError("cell_size from `cs_prev` should not be None.")

  # `dicfo` is the gradient w.r.t. the pre-activation gate block, packed in
  # the same [i, ci, f, o] order used by the forward op.
  (cs_prev_grad, dicfo, wci_grad, wcf_grad,
   wco_grad) = gen_lstm_ops.lstm_block_cell_grad(
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

  # Backprop from dicfo to xh.
  xh_grad = math_ops.matmul(dicfo, w, transpose_b=True)

  x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
  x_grad.get_shape().merge_with(x.get_shape())

  h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
                                (batch_size, cell_size))
  h_prev_grad.get_shape().merge_with(h_prev.get_shape())

  # Backprop from dicfo to w.
  xh = array_ops.concat([x, h_prev], 1)
  w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
  w_grad.get_shape().merge_with(w.get_shape())

  # Backprop from dicfo to b.
  b_grad = nn_ops.bias_add_grad(dicfo)
  b_grad.get_shape().merge_with(b.get_shape())

  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
          wco_grad, b_grad)

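
# Editorial sketch (not part of the original module): because the gradient
# above is registered for "LSTMBlockCell", ordinary autodiff picks it up
# automatically. The helper below is hypothetical; it wires up a single cell
# step with placeholder-free zero tensors and differentiates sum(h) w.r.t.
# the inputs, which routes through _LSTMBlockCellGrad.
def _example_lstm_block_cell_gradients(batch_size=2, input_size=3,
                                       num_units=4):
  """Builds one LSTM block cell step and returns grads of h (illustrative)."""
  from tensorflow.python.ops import gradients_impl  # Local, editorial import.

  x = array_ops.zeros([batch_size, input_size])
  cs_prev = array_ops.zeros([batch_size, num_units])
  h_prev = array_ops.zeros([batch_size, num_units])
  w = array_ops.zeros([input_size + num_units, num_units * 4])
  b = array_ops.zeros([num_units * 4])

  _, _, _, _, _, _, h = _lstm_block_cell(
      x, cs_prev, h_prev, w, b, forget_bias=1.0, use_peephole=False)
  # Returns [dL/dx, dL/dw, dL/db] for L = sum(h).
  return gradients_impl.gradients(math_ops.reduce_sum(h), [x, w, b])
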

@ops.RegisterGradient("BlockLSTM")
def _BlockLSTMGrad(op, *grad):
  """Gradient for BlockLSTM."""
  seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs
  i, cs, f, o, ci, co, h = op.outputs

  cs_grad = grad[1]
  h_grad = grad[6]

  (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad,
   b_grad) = gen_lstm_ops.block_lstm_grad(
       seq_len_max,
       x,
       cs_prev,
       h_prev,
       w,
       wci,
       wcf,
       wco,
       b,
       i,
       cs,
       f,
       o,
       ci,
       co,
       h,
       cs_grad,
       h_grad,
       use_peephole=op.get_attr("use_peephole"))

  return [
      None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
      wco_grad, b_grad
  ]


class LSTMBlockCell(LayerRNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add `forget_bias` (default: 1) to the biases of the forget gate in order
  to reduce the scale of forgetting at the beginning of training.

  Unlike `rnn_cell_impl.LSTMCell`, this is a monolithic op and should be much
  faster.  The weight and bias matrices should be compatible as long as the
  variable scope matches.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               cell_clip=None,
               use_peephole=False,
               dtype=None,
               reuse=None,
               name="lstm_cell"):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      cell_clip: An optional `float`. Defaults to `-1` (no clipping).
      use_peephole: Whether to use peephole connections or not.
      dtype: the variable dtype of this layer. Defaults to `tf.float32`.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope.  If not `True`, and the existing scope already has the
        given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.  By default this is "lstm_cell", for variable-name compatibility
        with `tf.nn.rnn_cell.LSTMCell`.

      When restoring from CudnnLSTM-trained checkpoints, you must use
      `CudnnCompatibleLSTMBlockCell` instead.
    """
    super(LSTMBlockCell, self).__init__(_reuse=reuse, dtype=dtype, name=name)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._use_peephole = use_peephole
    self._cell_clip = cell_clip if cell_clip is not None else -1
    self._names = {
        "W": "kernel",
        "b": "bias",
        "wci": "w_i_diag",
        "wcf": "w_f_diag",
        "wco": "w_o_diag",
        "scope": "lstm_cell"
    }
    # Inputs must be 2-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=2)

  @property
  def state_size(self):
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if not inputs_shape.dims[1].value:
      raise ValueError(
          "Expecting inputs_shape[1] to be set: %s" % str(inputs_shape))
    input_size = inputs_shape.dims[1].value
    self._kernel = self.add_variable(
        self._names["W"], [input_size + self._num_units, self._num_units * 4])
    self._bias = self.add_variable(
        self._names["b"], [self._num_units * 4],
        initializer=init_ops.constant_initializer(0.0))
    if self._use_peephole:
      self._w_i_diag = self.add_variable(self._names["wci"], [self._num_units])
      self._w_f_diag = self.add_variable(self._names["wcf"], [self._num_units])
      self._w_o_diag = self.add_variable(self._names["wco"], [self._num_units])

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM)."""
    if len(state) != 2:
      raise ValueError("Expecting state to be a tuple with length 2.")

    if self._use_peephole:
      wci = self._w_i_diag
      wcf = self._w_f_diag
      wco = self._w_o_diag
    else:
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=self.dtype)

    (cs_prev, h_prev) = state
    (_, cs, _, _, _, _, h) = _lstm_block_cell(
        inputs,
        cs_prev,
        h_prev,
        self._kernel,
        self._bias,
        wci=wci,
        wcf=wcf,
        wco=wco,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)

    new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
    return h, new_state


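# Editorial usage sketch (not part of the original module): LSTMBlockCell is
# a drop-in replacement for `tf.nn.rnn_cell.LSTMCell` inside the standard RNN
# helpers. The helper below is hypothetical; shapes and names are
# illustrative only.
def _example_lstm_block_cell_usage(batch_size=8, time_len=10, input_size=16,
                                   num_units=32):
  """Runs LSTMBlockCell over a batch-major sequence with dynamic_rnn."""
  from tensorflow.python.ops import rnn  # Local, editorial import.

  cell = LSTMBlockCell(num_units)
  inputs = array_ops.zeros([batch_size, time_len, input_size])
  # outputs: [batch_size, time_len, num_units]; final_state: LSTMStateTuple.
  outputs, final_state = rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
  return outputs, final_state

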
@six.add_metaclass(abc.ABCMeta)
class LSTMBlockWrapper(base_layer.Layer):
  """This is a helper class that provides housekeeping for LSTM cells.

  This may be useful for alternative LSTM and similar types of cells.
  The subclasses must implement the `_call_cell` method and the `num_units`
  property.
  """

  @abc.abstractproperty
  def num_units(self):
    """Number of units in this cell (output dimension)."""
    pass

  @abc.abstractmethod
  def _call_cell(self, inputs, initial_cell_state, initial_output, dtype,
                 sequence_length):
    """Run this LSTM on inputs, starting from the given state.

    This method must be implemented by subclasses and does the actual work
    of calling the cell.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, values in
        `[0, time_len)`, or None.

    Returns:
      A pair containing:

      - State: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
    """
    pass

  def call(self, inputs, initial_state=None, dtype=None, sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`.
      initial_state: a tuple `(initial_cell_state, initial_output)` with
        tensors of shape `[batch_size, self._num_units]`. If this is not
        provided, the cell is expected to create a zero initial state of type
        `dtype`.
      dtype: The data type for the initial state and expected output. Required
        if `initial_state` is not provided or RNN state has a heterogeneous
        dtype.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, values in
        `[0, time_len)`. Defaults to `time_len` for each element.

    Returns:
      A pair containing:

      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
        or a list of time_len tensors of shape `[batch_size, output_size]`,
        to match the type of the `inputs`.
      - Final state: a tuple `(cell_state, output)` matching `initial_state`.

    Raises:
      ValueError: in case of shape mismatches
    """
    is_list = isinstance(inputs, list)
    if is_list:
      inputs = array_ops.stack(inputs)
    inputs_shape = inputs.get_shape().with_rank(3)
    if not inputs_shape[2]:
      raise ValueError("Expecting inputs_shape[2] to be set: %s" % inputs_shape)
    batch_size = inputs_shape.dims[1].value
    if batch_size is None:
      batch_size = array_ops.shape(inputs)[1]
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    # Provide default values for initial_state and dtype.
    if initial_state is None:
      if dtype is None:
        raise ValueError("Either initial_state or dtype needs to be specified")
      z = array_ops.zeros(
          array_ops.stack([batch_size, self.num_units]), dtype=dtype)
      initial_state = z, z
    else:
      if len(initial_state) != 2:
        raise ValueError(
            "Expecting initial_state to be a tuple with length 2 or None")
      if dtype is None:
        dtype = initial_state[0].dtype

    # Create the actual cell.
    if sequence_length is not None:
      sequence_length = ops.convert_to_tensor(sequence_length)
    initial_cell_state, initial_output = initial_state  # pylint: disable=unpacking-non-sequence
    cell_states, outputs = self._call_cell(
        inputs, initial_cell_state, initial_output, dtype, sequence_length)

    if sequence_length is not None:
      # Mask out the part beyond sequence_length.
      mask = array_ops.transpose(
          array_ops.sequence_mask(sequence_length, time_len, dtype=dtype),
          [1, 0])
      mask = array_ops.tile(
          array_ops.expand_dims(mask, [-1]), [1, 1, self.num_units])
      outputs *= mask
      # Prepend initial states to cell_states and outputs for indexing to work
      # correctly, since we want to access the last valid state at
      # sequence_length - 1, which can even be -1, corresponding to the
      # initial state.
      mod_cell_states = array_ops.concat(
          [array_ops.expand_dims(initial_cell_state, [0]), cell_states], 0)
      mod_outputs = array_ops.concat(
          [array_ops.expand_dims(initial_output, [0]), outputs], 0)
      final_cell_state = self._gather_states(mod_cell_states, sequence_length,
                                             batch_size)
      final_output = self._gather_states(mod_outputs, sequence_length,
                                         batch_size)
    else:
      # No sequence_lengths used: final state is the last state.
      final_cell_state = cell_states[-1]
      final_output = outputs[-1]

    if is_list:
      # Input was a list, so return a list.
      outputs = array_ops.unstack(outputs)

    final_state = rnn_cell_impl.LSTMStateTuple(final_cell_state, final_output)
    return outputs, final_state

  def _gather_states(self, data, indices, batch_size):
    """Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
    return array_ops.gather_nd(
        data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))


class LSTMBlockFusedCell(LSTMBlockWrapper):
  """FusedRNNCell implementation of LSTM.

  This is an extremely efficient LSTM implementation that uses a single TF op
  for the entire LSTM. It should be both faster and more memory-efficient than
  LSTMBlockCell defined above.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order
  to reduce the scale of forgetting at the beginning of training.

  The variable naming is consistent with `rnn_cell_impl.LSTMCell`.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               cell_clip=None,
               use_peephole=False,
               reuse=None,
               dtype=None,
               name="lstm_fused_cell"):
    """Initialize the LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      cell_clip: clip the cell to this value. Default is no cell clipping.
      use_peephole: Whether to use peephole connections or not.
      reuse: (optional) boolean describing whether to reuse variables in an
        existing scope.  If not `True`, and the existing scope already has the
        given variables, an error is raised.
      dtype: the dtype of variables of this layer.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.  By default this is "lstm_fused_cell", for variable-name
        compatibility with `tf.nn.rnn_cell.LSTMCell`.
    """
    super(LSTMBlockFusedCell, self).__init__(
        _reuse=reuse, name=name, dtype=dtype)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._cell_clip = cell_clip if cell_clip is not None else -1
    self._use_peephole = use_peephole

    # Inputs must be 3-dimensional.
    self.input_spec = input_spec.InputSpec(ndim=3)

  @property
  def num_units(self):
    """Number of units in this cell (output dimension)."""
    return self._num_units

  def build(self, input_shape):
    input_size = input_shape.dims[2].value
    self._kernel = self.add_variable(
        "kernel", [input_size + self._num_units, self._num_units * 4])
    self._bias = self.add_variable(
        "bias", [self._num_units * 4],
        initializer=init_ops.constant_initializer(0.0))
    if self._use_peephole:
      self._w_i_diag = self.add_variable("w_i_diag", [self._num_units])
      self._w_f_diag = self.add_variable("w_f_diag", [self._num_units])
      self._w_o_diag = self.add_variable("w_o_diag", [self._num_units])

    self.built = True

  def _call_cell(self,
                 inputs,
                 initial_cell_state=None,
                 initial_output=None,
                 dtype=None,
                 sequence_length=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
      initial_cell_state: initial value for cell state, shape `[batch_size,
        self._num_units]`
      initial_output: initial value of cell output, shape `[batch_size,
        self._num_units]`
      dtype: The data type for the initial state and expected output.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, values in
        `[0, time_len)`, or None.

    Returns:
      A pair containing:

      - Cell state (cs): A `3-D` tensor of shape `[time_len, batch_size,
                         output_size]`
      - Output (h): A `3-D` tensor of shape `[time_len, batch_size,
                    output_size]`
    """

    inputs_shape = inputs.get_shape().with_rank(3)
    time_len = inputs_shape.dims[0].value
    if time_len is None:
      time_len = array_ops.shape(inputs)[0]

    if self._use_peephole:
      wci = self._w_i_diag
      wco = self._w_o_diag
      wcf = self._w_f_diag
    else:
      wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype)

    if sequence_length is None:
      max_seq_len = math_ops.cast(time_len, dtypes.int64)
    else:
      max_seq_len = math_ops.cast(math_ops.reduce_max(sequence_length),
                                  dtypes.int64)

    _, cs, _, _, _, _, h = gen_lstm_ops.block_lstm(
        seq_len_max=max_seq_len,
        x=inputs,
        cs_prev=initial_cell_state,
        h_prev=initial_output,
        w=self._kernel,
        wci=wci,
        wcf=wcf,
        wco=wco,
        b=self._bias,
        forget_bias=self._forget_bias,
        cell_clip=self._cell_clip,
        use_peephole=self._use_peephole)
    return cs, h
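

# Editorial usage sketch (not part of the original module): unlike
# LSTMBlockCell, the fused cell consumes the whole time-major sequence in a
# single call rather than being driven by an RNN loop. The helper below is
# hypothetical; shapes and names are illustrative only.
def _example_lstm_block_fused_cell_usage(time_len=10, batch_size=8,
                                         input_size=16, num_units=32):
  """Runs LSTMBlockFusedCell over a time-major sequence (illustrative)."""
  cell = LSTMBlockFusedCell(num_units)
  # Note the [time_len, batch_size, input_size] (time-major) layout.
  inputs = array_ops.zeros([time_len, batch_size, input_size])
  seq_lens = array_ops.fill([batch_size], time_len)
  # outputs: [time_len, batch_size, num_units]; state: LSTMStateTuple.
  outputs, state = cell(inputs, dtype=dtypes.float32,
                        sequence_length=seq_lens)
  return outputs, state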