1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""RNN helpers for TensorFlow models."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.eager import context
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import control_flow_util
29from tensorflow.python.ops import control_flow_util_v2
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import rnn_cell_impl
32from tensorflow.python.ops import tensor_array_ops
33from tensorflow.python.ops import variable_scope as vs
34from tensorflow.python.util import deprecation
35from tensorflow.python.util import dispatch
36from tensorflow.python.util import nest
37from tensorflow.python.util.tf_export import tf_export
38
39# pylint: disable=protected-access
40_concat = rnn_cell_impl._concat
41# pylint: enable=protected-access
42
43
44def _transpose_batch_time(x):
45  """Transposes the batch and time dimensions of a Tensor.
46
47  If the input tensor has rank < 2, it is returned unchanged. Retains as
48  much of the static shape information as possible.
49
50  Args:
51    x: A Tensor.
52
53  Returns:
54    x transposed along the first two dimensions.
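
  For illustration only, a minimal sketch (the shapes below are hypothetical):
  the first two dimensions are swapped and the remaining dimensions are left
  untouched.

  ```python
  x = tf.zeros([2, 3, 4])          # [batch, time, depth]
  x_t = _transpose_batch_time(x)   # shape [3, 2, 4], i.e. [time, batch, depth]
  ```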
55  """
56  x_static_shape = x.get_shape()
57  if x_static_shape.rank is not None and x_static_shape.rank < 2:
58    return x
59
60  x_rank = array_ops.rank(x)
61  x_t = array_ops.transpose(
62      x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0))
63  x_t.set_shape(
64      tensor_shape.TensorShape(
65          [x_static_shape.dims[1].value,
66           x_static_shape.dims[0].value]).concatenate(x_static_shape[2:]))
67  return x_t
68
69
70def _best_effort_input_batch_size(flat_input):
71  """Get static input batch size if available, with fallback to the dynamic one.
72
73  Args:
74    flat_input: An iterable of time major input Tensors of shape `[max_time,
75      batch_size, ...]`. All inputs should have compatible batch sizes.
76
77  Returns:
78    The batch size in Python integer if available, or a scalar Tensor otherwise.
79
80  Raises:
81    ValueError: if there is any input with an invalid shape.
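
  For illustration, a minimal sketch (hypothetical shapes): the static batch
  size is read from the second dimension when it is known; otherwise a dynamic
  scalar is returned.

  ```python
  x = tf.zeros([10, 32, 8])                    # fully static shape
  _best_effort_input_batch_size([x])           # -> 32 (a Python int)

  y = tf.compat.v1.placeholder(tf.float32, [10, None, 8])
  _best_effort_input_batch_size([y])           # -> a scalar int32 Tensor
  ```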
82  """
83  for input_ in flat_input:
84    shape = input_.shape
85    if shape.rank is None:
86      continue
87    if shape.rank < 2:
88      raise ValueError("Expected input tensor %s to have rank at least 2" %
89                       input_)
90    batch_size = shape.dims[1].value
91    if batch_size is not None:
92      return batch_size
93  # Fallback to the dynamic batch size of the first input.
94  return array_ops.shape(flat_input[0])[1]
95
96
97def _infer_state_dtype(explicit_dtype, state):
98  """Infer the dtype of an RNN state.
99
100  Args:
101    explicit_dtype: explicitly declared dtype or None.
102    state: RNN's hidden state. Must be a Tensor or a nested iterable containing
103      Tensors.
104
105  Returns:
106    dtype: inferred dtype of hidden state.
107
108  Raises:
109    ValueError: if `state` has heterogeneous dtypes or is empty.
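
  For illustration, a minimal sketch (hypothetical state tensors): an explicit
  dtype wins; otherwise the dtype is taken from the flattened state.

  ```python
  c = tf.zeros([4, 8], dtype=tf.float64)
  h = tf.zeros([4, 8], dtype=tf.float64)
  _infer_state_dtype(None, (c, h))         # -> tf.float64
  _infer_state_dtype(tf.float32, (c, h))   # -> tf.float32
  ```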
110  """
111  if explicit_dtype is not None:
112    return explicit_dtype
113  elif nest.is_sequence(state):
114    inferred_dtypes = [element.dtype for element in nest.flatten(state)]
115    if not inferred_dtypes:
116      raise ValueError("Unable to infer dtype from empty state.")
117    all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes)
118    if not all_same:
119      raise ValueError(
120          "State has tensors of different inferred_dtypes. Unable to infer a "
121          "single representative dtype.")
122    return inferred_dtypes[0]
123  else:
124    return state.dtype
125
126
127def _maybe_tensor_shape_from_tensor(shape):
128  if isinstance(shape, ops.Tensor):
129    return tensor_shape.as_shape(tensor_util.constant_value(shape))
130  else:
131    return shape
132
133
134def _should_cache():
135  """Returns True if a default caching device should be set, otherwise False."""
136  if context.executing_eagerly():
137    return False
138  # Don't set a caching device when running in a loop, since it is possible that
139  # train steps could be wrapped in a tf.while_loop. In that scenario caching
140  # prevents forward computations in loop iterations from re-reading the
141  # updated weights.
142  graph = ops.get_default_graph()
143  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
144  in_v1_while_loop = (
145      control_flow_util.GetContainingWhileContext(ctxt) is not None)
146  in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph)
147  return not in_v1_while_loop and not in_v2_while_loop
148
149
150# pylint: disable=unused-argument
151def _rnn_step(time,
152              sequence_length,
153              min_sequence_length,
154              max_sequence_length,
155              zero_output,
156              state,
157              call_cell,
158              state_size,
159              skip_conditionals=False):
160  """Calculate one step of a dynamic RNN minibatch.
161
162  Returns an (output, state) pair conditioned on `sequence_length`.
163  When skip_conditionals=False, the pseudocode is something like:
164
165  if t >= max_sequence_length:
166    return (zero_output, state)
167  if t < min_sequence_length:
168    return call_cell()
169
170  # Selectively output zeros or output, old state or new state depending
171  # on whether we've finished calculating each row.
172  new_output, new_state = call_cell()
173  final_output = np.vstack([
174    zero_output if time >= sequence_length[r] else new_output_r
175    for r, new_output_r in enumerate(new_output)
176  ])
177  final_state = np.vstack([
178    state[r] if time >= sequence_length[r] else new_state_r
179    for r, new_state_r in enumerate(new_state)
180  ])
181  return (final_output, final_state)
182
183  Args:
184    time: int32 `Tensor` scalar.
185    sequence_length: int32 `Tensor` vector of size [batch_size].
186    min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
187    max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
188    zero_output: `Tensor` vector of shape [output_size].
189    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
190      or a list/tuple of such tensors.
191    call_cell: lambda returning tuple of (new_output, new_state) where
192      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
193      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
194    state_size: The `cell.state_size` associated with the state.
195    skip_conditionals: Python bool, whether to skip using the conditional
196      calculations.  This is useful for `dynamic_rnn`, where the input tensor
197      matches `max_sequence_length`, and using conditionals just slows
198      everything down.
199
200  Returns:
201    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
202      final_output is a `Tensor` matrix of shape [batch_size, output_size]
203      final_state is either a single `Tensor` matrix, or a tuple of such
204        matrices (matching length and shapes of input `state`).
205
206  Raises:
207    ValueError: If the cell returns a state tuple whose length does not match
208      that returned by `state_size`.
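
  For reference, a sketch of how `_dynamic_rnn_loop` below invokes this helper
  (the surrounding names are taken from that caller):

  ```python
  call_cell = lambda: cell(input_t, state)
  output, new_state = _rnn_step(
      time=time,
      sequence_length=sequence_length,
      min_sequence_length=min_sequence_length,
      max_sequence_length=max_sequence_length,
      zero_output=zero_output,
      state=state,
      call_cell=call_cell,
      state_size=cell.state_size,
      skip_conditionals=True)
  ```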
209  """
210
211  # Convert state to a list for ease of use
212  flat_state = nest.flatten(state)
213  flat_zero_output = nest.flatten(zero_output)
214
215  # Vector describing which batch entries are finished.
216  copy_cond = time >= sequence_length
217
218  def _copy_one_through(output, new_output):
219    # TensorArray and scalar get passed through.
220    if isinstance(output, tensor_array_ops.TensorArray):
221      return new_output
222    if output.shape.rank == 0:
223      return new_output
224    # Otherwise propagate the old or the new value.
225    with ops.colocate_with(new_output):
226      return array_ops.where(copy_cond, output, new_output)
227
228  def _copy_some_through(flat_new_output, flat_new_state):
229    # Use broadcasting select to determine which values should get
230    # the previous state & zero output, and which values should get
231    # a calculated state & output.
232    flat_new_output = [
233        _copy_one_through(zero_output, new_output)
234        for zero_output, new_output in zip(flat_zero_output, flat_new_output)
235    ]
236    flat_new_state = [
237        _copy_one_through(state, new_state)
238        for state, new_state in zip(flat_state, flat_new_state)
239    ]
240    return flat_new_output + flat_new_state
241
242  def _maybe_copy_some_through():
243    """Run RNN step.  Pass through either no or some past state."""
244    new_output, new_state = call_cell()
245
246    nest.assert_same_structure(zero_output, new_output)
247    nest.assert_same_structure(state, new_state)
248
249    flat_new_state = nest.flatten(new_state)
250    flat_new_output = nest.flatten(new_output)
251    return control_flow_ops.cond(
252        # if t < min_seq_len: calculate and return everything
253        time < min_sequence_length,
254        lambda: flat_new_output + flat_new_state,
255        # else copy some of it through
256        lambda: _copy_some_through(flat_new_output, flat_new_state))
257
258  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
259  # but benefits from removing cond() and its gradient.  We should
260  # profile with and without this switch here.
261  if skip_conditionals:
262    # Instead of using conditionals, perform the selective copy at all time
263    # steps.  This is faster when max_seq_len is equal to the number of unrolls
264    # (which is typical for dynamic_rnn).
265    new_output, new_state = call_cell()
266    nest.assert_same_structure(zero_output, new_output)
267    nest.assert_same_structure(state, new_state)
268    new_state = nest.flatten(new_state)
269    new_output = nest.flatten(new_output)
270    final_output_and_state = _copy_some_through(new_output, new_state)
271  else:
272    empty_update = lambda: flat_zero_output + flat_state
273    final_output_and_state = control_flow_ops.cond(
274        # if t >= max_seq_len: copy all state through, output zeros
275        time >= max_sequence_length,
276        empty_update,
277        # otherwise calculation is required: copy some or all of it through
278        _maybe_copy_some_through)
279
280  if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
281    raise ValueError("Internal error: state and output were not concatenated "
282                     "correctly.")
283  final_output = final_output_and_state[:len(flat_zero_output)]
284  final_state = final_output_and_state[len(flat_zero_output):]
285
286  for output, flat_output in zip(final_output, flat_zero_output):
287    output.set_shape(flat_output.get_shape())
288  for substate, flat_substate in zip(final_state, flat_state):
289    if not isinstance(substate, tensor_array_ops.TensorArray):
290      substate.set_shape(flat_substate.get_shape())
291
292  final_output = nest.pack_sequence_as(
293      structure=zero_output, flat_sequence=final_output)
294  final_state = nest.pack_sequence_as(
295      structure=state, flat_sequence=final_state)
296
297  return final_output, final_state
298
299
300def _reverse_seq(input_seq, lengths):
301  """Reverse a list of Tensors up to specified lengths.
302
303  Args:
304    input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
305      or nested tuples of tensors.
306    lengths: A `Tensor` of dimension `batch_size`, containing the length of
307      each sequence in the batch. If `None`, the list is simply reversed.
308
309  Returns:
310    time-reversed sequence
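
  For illustration, a minimal sketch (hypothetical values; batch of 2, three
  time steps): only the first `lengths[b]` steps of batch entry `b` are
  reversed, and padding steps stay in place.

  ```python
  seq = [tf.constant([[1.], [10.]]),
         tf.constant([[2.], [20.]]),
         tf.constant([[3.], [30.]])]
  rev = _reverse_seq(seq, tf.constant([3, 2]))
  # Entry 0 becomes 3, 2, 1; entry 1 becomes 20, 10 with 30 left in place.
  ```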
311  """
312  if lengths is None:
313    return list(reversed(input_seq))
314
315  flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)
316
317  flat_results = [[] for _ in range(len(input_seq))]
318  for sequence in zip(*flat_input_seq):
319    input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank)
320    for input_ in sequence:
321      input_shape.assert_is_compatible_with(input_.get_shape())
322      input_.set_shape(input_shape)
323
324    # Join into (time, batch_size, depth)
325    s_joined = array_ops.stack(sequence)
326
327    # Reverse along dimension 0
328    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
329    # Split again into list
330    result = array_ops.unstack(s_reversed)
331    for r, flat_result in zip(result, flat_results):
332      r.set_shape(input_shape)
333      flat_result.append(r)
334
335  results = [
336      nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
337      for input_, flat_result in zip(input_seq, flat_results)
338  ]
339  return results
340
341
342@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
343                        "keras.layers.RNN(cell))`, which is equivalent to "
344                        "this API")
345@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
346@dispatch.add_dispatch_support
347def bidirectional_dynamic_rnn(cell_fw,
348                              cell_bw,
349                              inputs,
350                              sequence_length=None,
351                              initial_state_fw=None,
352                              initial_state_bw=None,
353                              dtype=None,
354                              parallel_iterations=None,
355                              swap_memory=False,
356                              time_major=False,
357                              scope=None):
358  """Creates a dynamic version of bidirectional recurrent neural network.
359
360  Takes input and builds independent forward and backward RNNs. The input_size
361  of the forward and backward cells must match. The initial state for both
362  directions defaults to zero (but can be set optionally), and no intermediate
363  states are ever returned -- the network is fully unrolled for the given
364  length(s) of the sequence(s), or completely unrolled if no lengths are
365  given.
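
  For illustration, a hedged usage sketch (the sizes and the `seq_len`
  placeholder below are hypothetical); the forward and backward outputs can be
  concatenated if a single tensor is preferred:

  ```python
  inputs = tf.compat.v1.placeholder(tf.float32, [None, max_time, input_depth])
  seq_len = tf.compat.v1.placeholder(tf.int32, [None])
  cell_fw = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
  cell_bw = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
  (out_fw, out_bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(
      cell_fw, cell_bw, inputs, sequence_length=seq_len, dtype=tf.float32)
  outputs = tf.concat([out_fw, out_bw], 2)   # [batch, max_time, 128]
  ```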
366
367  Args:
368    cell_fw: An instance of RNNCell, to be used for forward direction.
369    cell_bw: An instance of RNNCell, to be used for backward direction.
370    inputs: The RNN inputs.
371      If time_major == False (default), this must be a tensor of shape:
372        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
373      If time_major == True, this must be a tensor of shape: `[max_time,
374        batch_size, ...]`, or a nested tuple of such elements.
375    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
376      containing the actual lengths for each of the sequences in the batch. If
377      not provided, all batch entries are assumed to be full sequences; and time
378      reversal is applied from time `0` to `max_time` for each sequence.
379    initial_state_fw: (optional) An initial state for the forward RNN. This must
380      be a tensor of appropriate type and shape `[batch_size,
381      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
382      tuple of tensors having shapes `[batch_size, s] for s in
383      cell_fw.state_size`.
384    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
385      corresponding properties of `cell_bw`.
386    dtype: (optional) The data type for the initial states and expected output.
387      Required if initial_states are not provided or RNN states have a
388      heterogeneous dtype.
389    parallel_iterations: (Default: 32).  The number of iterations to run in
390      parallel.  Those operations which do not have any temporal dependency and
391      can be run in parallel, will be.  This parameter trades off time for
392      space.  Values >> 1 use more memory but take less time, while smaller
393      values use less memory but computations take longer.
394    swap_memory: Transparently swap the tensors produced in forward inference
395      but needed for back prop from GPU to CPU.  This allows training RNNs which
396      would typically not fit on a single GPU, with very minimal (or no)
397      performance penalty.
398    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
399      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
400      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
401      `time_major = True` is a bit more efficient because it avoids transposes
402      at the beginning and end of the RNN calculation.  However, most TensorFlow
403      data is batch-major, so by default this function accepts input and emits
404      output in batch-major form.
405    scope: VariableScope for the created subgraph; defaults to
406      "bidirectional_rnn"
407
408  Returns:
409    A tuple (outputs, output_states) where:
410      outputs: A tuple (output_fw, output_bw) containing the forward and
411        the backward rnn output `Tensor`.
412        If time_major == False (default),
413          output_fw will be a `Tensor` shaped:
414          `[batch_size, max_time, cell_fw.output_size]`
415          and output_bw will be a `Tensor` shaped:
416          `[batch_size, max_time, cell_bw.output_size]`.
417        If time_major == True,
418          output_fw will be a `Tensor` shaped:
419          `[max_time, batch_size, cell_fw.output_size]`
420          and output_bw will be a `Tensor` shaped:
421          `[max_time, batch_size, cell_bw.output_size]`.
422        It returns a tuple instead of a single concatenated `Tensor`, unlike
423        in `bidirectional_rnn`. If a concatenated tensor is preferred,
424        the forward and backward outputs can be concatenated as
425        `tf.concat(outputs, 2)`.
426      output_states: A tuple (output_state_fw, output_state_bw) containing
427        the forward and the backward final states of bidirectional rnn.
428
429  Raises:
430    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
431  """
432  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
433  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
434
435  with vs.variable_scope(scope or "bidirectional_rnn"):
436    # Forward direction
437    with vs.variable_scope("fw") as fw_scope:
438      output_fw, output_state_fw = dynamic_rnn(
439          cell=cell_fw,
440          inputs=inputs,
441          sequence_length=sequence_length,
442          initial_state=initial_state_fw,
443          dtype=dtype,
444          parallel_iterations=parallel_iterations,
445          swap_memory=swap_memory,
446          time_major=time_major,
447          scope=fw_scope)
448
449    # Backward direction
450    if not time_major:
451      time_axis = 1
452      batch_axis = 0
453    else:
454      time_axis = 0
455      batch_axis = 1
456
457    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
458      if seq_lengths is not None:
459        return array_ops.reverse_sequence(
460            input=input_,
461            seq_lengths=seq_lengths,
462            seq_axis=seq_axis,
463            batch_axis=batch_axis)
464      else:
465        return array_ops.reverse(input_, axis=[seq_axis])
466
467    with vs.variable_scope("bw") as bw_scope:
468
469      def _map_reverse(inp):
470        return _reverse(
471            inp,
472            seq_lengths=sequence_length,
473            seq_axis=time_axis,
474            batch_axis=batch_axis)
475
476      inputs_reverse = nest.map_structure(_map_reverse, inputs)
477      tmp, output_state_bw = dynamic_rnn(
478          cell=cell_bw,
479          inputs=inputs_reverse,
480          sequence_length=sequence_length,
481          initial_state=initial_state_bw,
482          dtype=dtype,
483          parallel_iterations=parallel_iterations,
484          swap_memory=swap_memory,
485          time_major=time_major,
486          scope=bw_scope)
487
488  output_bw = _reverse(
489      tmp,
490      seq_lengths=sequence_length,
491      seq_axis=time_axis,
492      batch_axis=batch_axis)
493
494  outputs = (output_fw, output_bw)
495  output_states = (output_state_fw, output_state_bw)
496
497  return (outputs, output_states)
498
499
500@deprecation.deprecated(
501    None,
502    "Please use `keras.layers.RNN(cell)`, which is equivalent to this API")
503@tf_export(v1=["nn.dynamic_rnn"])
504@dispatch.add_dispatch_support
505def dynamic_rnn(cell,
506                inputs,
507                sequence_length=None,
508                initial_state=None,
509                dtype=None,
510                parallel_iterations=None,
511                swap_memory=False,
512                time_major=False,
513                scope=None):
514  """Creates a recurrent neural network specified by RNNCell `cell`.
515
516  Performs fully dynamic unrolling of `inputs`.
517
518  Example:
519
520  ```python
521  # create a BasicRNNCell
522  rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)
523
524  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
525
526  # defining initial state
527  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
528
529  # 'state' is a tensor of shape [batch_size, cell_state_size]
530  outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
531                                     initial_state=initial_state,
532                                     dtype=tf.float32)
533  ```
534
535  ```python
536  # create 2 LSTMCells
537  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
538
539  # create a RNN cell composed sequentially of a number of RNNCells
540  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)
541
542  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
543  # 'state' is an N-tuple, where N is the number of LSTMCells, containing a
544  # tf.nn.rnn_cell.LSTMStateTuple for each cell
545  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
546                                     inputs=data,
547                                     dtype=tf.float32)
548  ```
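
  A further hedged sketch (hypothetical shapes) of the `sequence_length`
  behavior: outputs past each row's length are zeroed and the last valid state
  is copied through to the returned final state.

  ```python
  data = tf.compat.v1.placeholder(tf.float32, [None, 20, 16])
  lengths = tf.compat.v1.placeholder(tf.int32, [None])
  cell = tf.compat.v1.nn.rnn_cell.GRUCell(32)
  outputs, state = tf.compat.v1.nn.dynamic_rnn(
      cell, data, sequence_length=lengths, dtype=tf.float32)
  ```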
549
550
551  Args:
552    cell: An instance of RNNCell.
553    inputs: The RNN inputs.
554      If `time_major == False` (default), this must be a `Tensor` of shape:
555        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
556      If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
557        batch_size, ...]`, or a nested tuple of such elements. This may also be
558        a (possibly nested) tuple of Tensors satisfying this property.  The
559        first two dimensions must match across all the inputs, but otherwise the
560        ranks and other shape components may differ. In this case, input to
561        `cell` at each time-step will replicate the structure of these tuples,
562        except for the time dimension (from which the time is taken). The input
563        to `cell` at each time step will be a `Tensor` or (possibly nested)
564        tuple of Tensors each with dimensions `[batch_size, ...]`.
565    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
566      to copy-through state and zero-out outputs when past a batch element's
567      sequence length.  This parameter enables users to extract the last valid
568      state and properly padded outputs, so it is provided for correctness.
569    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
570      is an integer, this must be a `Tensor` of appropriate type and shape
571      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
572      should be a tuple of tensors having shapes `[batch_size, s] for s in
573      cell.state_size`.
574    dtype: (optional) The data type for the initial state and expected output.
575      Required if initial_state is not provided or RNN state has a heterogeneous
576      dtype.
577    parallel_iterations: (Default: 32).  The number of iterations to run in
578      parallel.  Those operations which do not have any temporal dependency and
579      can be run in parallel, will be.  This parameter trades off time for
580      space.  Values >> 1 use more memory but take less time, while smaller
581      values use less memory but computations take longer.
582    swap_memory: Transparently swap the tensors produced in forward inference
583      but needed for back prop from GPU to CPU.  This allows training RNNs which
584      would typically not fit on a single GPU, with very minimal (or no)
585      performance penalty.
586    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
587      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
588      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
589      `time_major = True` is a bit more efficient because it avoids transposes
590      at the beginning and end of the RNN calculation.  However, most TensorFlow
591      data is batch-major, so by default this function accepts input and emits
592      output in batch-major form.
593    scope: VariableScope for the created subgraph; defaults to "rnn".
594
595  Returns:
596    A pair (outputs, state) where:
597
598    outputs: The RNN output `Tensor`.
599
600      If time_major == False (default), this will be a `Tensor` shaped:
601        `[batch_size, max_time, cell.output_size]`.
602
603      If time_major == True, this will be a `Tensor` shaped:
604        `[max_time, batch_size, cell.output_size]`.
605
606      Note, if `cell.output_size` is a (possibly nested) tuple of integers
607      or `TensorShape` objects, then `outputs` will be a tuple having the
608      same structure as `cell.output_size`, containing Tensors having shapes
609      corresponding to the shape data in `cell.output_size`.
610
611    state: The final state.  If `cell.state_size` is an int, this
612      will be shaped `[batch_size, cell.state_size]`.  If it is a
613      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
614      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
615      be a tuple having the corresponding shapes. If cells are `LSTMCells`
616      `state` will be a tuple containing a `LSTMStateTuple` for each cell.
617
618  Raises:
619    TypeError: If `cell` is not an instance of RNNCell.
620    ValueError: If inputs is None or an empty list.
621
622  @compatibility(TF2)
623  `tf.compat.v1.nn.dynamic_rnn` is not compatible with eager execution and
624  `tf.function`. Please use `tf.keras.layers.RNN` instead for TF2 migration.
625  Take LSTM as an example, you can instantiate a `tf.keras.layers.RNN` layer
626  with `tf.keras.layers.LSTMCell`, or directly via `tf.keras.layers.LSTM`. Once
627  the keras layer is created, you can get the output and states by calling
628  the layer with input and states. Please refer to [this
629  guide](https://www.tensorflow.org/guide/keras/rnn) for more details about
630  Keras RNN. You can also find more details about the difference and comparison
631  between Keras RNN and TF compat v1 rnn in [this
632  document](https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md)
633
634  #### Structural Mapping to Native TF2
635
636  Before:
637
638  ```python
639  # create 2 LSTMCells
640  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
641
642  # create a RNN cell composed sequentially of a number of RNNCells
643  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)
644
645  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
646  # 'state' is an N-tuple, where N is the number of LSTMCells, containing a
647  # tf.nn.rnn_cell.LSTMStateTuple for each cell
648  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
649                                               inputs=data,
650                                               dtype=tf.float32)
651  ```
652
653  After:
654
655  ```python
656  # RNN layer can take a list of cells, which will then stack them together.
657  # By default, keras RNN will only return the last timestep output and will not
658  # return states. If you need whole time sequence output as well as the states,
659  # you can set `return_sequences` and `return_state` to True.
660  rnn_layer = tf.keras.layers.RNN([tf.keras.layers.LSTMCell(128),
661                                   tf.keras.layers.LSTMCell(256)],
662                                  return_sequences=True,
663                                  return_state=True)
664  outputs, output_states = rnn_layer(inputs, states)
665  ```
666
667  #### How to Map Arguments
668
669  | TF1 Arg Name          | TF2 Arg Name    | Note                             |
670  | :-------------------- | :-------------- | :------------------------------- |
671  | `cell`                | `cell`          | In the RNN layer constructor     |
672  | `inputs`              | `inputs`        | In the RNN layer `__call__`      |
673  | `sequence_length`     | Not used        | Add a masking layer before the   :
674  :                       :                 : RNN (see the sketch below).      :
675  | `initial_state`       | `initial_state` | In the RNN layer `__call__`      |
676  | `dtype`               | `dtype`         | In the RNN layer constructor     |
677  | `parallel_iterations` | Not supported   |                                  |
678  | `swap_memory`         | Not supported   |                                  |
679  | `time_major`          | `time_major`    | In the RNN layer constructor     |
680  | `scope`               | Not supported   |                                  |
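
  For the `sequence_length` row above, a hedged sketch of the masking approach
  (it assumes padded time steps are all-zero, which is what
  `tf.keras.layers.Masking` detects by default):

  ```python
  masked = tf.keras.layers.Masking()(inputs)   # padded steps are all zeros
  outputs = tf.keras.layers.RNN(
      tf.keras.layers.LSTMCell(128), return_sequences=True)(masked)
  ```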
681  @end_compatibility
682  """
683  rnn_cell_impl.assert_like_rnncell("cell", cell)
684
685  with vs.variable_scope(scope or "rnn") as varscope:
686    # Create a new scope in which the caching device is either
687    # determined by the parent scope, or is set to place the cached
688    # Variable using the same placement as for the rest of the RNN.
689    if _should_cache():
690      if varscope.caching_device is None:
691        varscope.set_caching_device(lambda op: op.device)
692
693    # By default, time_major==False and inputs are batch-major: shaped
694    #   [batch, time, depth]
695    # For internal calculations, we transpose to [time, batch, depth]
696    flat_input = nest.flatten(inputs)
697
698    if not time_major:
699      # (B,T,D) => (T,B,D)
700      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
701      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
702
703    parallel_iterations = parallel_iterations or 32
704    if sequence_length is not None:
705      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
706      if sequence_length.get_shape().rank not in (None, 1):
707        raise ValueError(
708            "sequence_length must be a vector of length batch_size, "
709            "but saw shape: %s" % sequence_length.get_shape())
710      sequence_length = array_ops.identity(  # Just to find it in the graph.
711          sequence_length,
712          name="sequence_length")
713
714    batch_size = _best_effort_input_batch_size(flat_input)
715
716    if initial_state is not None:
717      state = initial_state
718    else:
719      if not dtype:
720        raise ValueError("If there is no initial_state, you must give a dtype.")
721      if getattr(cell, "get_initial_state", None) is not None:
722        state = cell.get_initial_state(
723            inputs=None, batch_size=batch_size, dtype=dtype)
724      else:
725        state = cell.zero_state(batch_size, dtype)
726
727    def _assert_has_shape(x, shape):
728      x_shape = array_ops.shape(x)
729      packed_shape = array_ops.stack(shape)
730      return control_flow_ops.Assert(
731          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
732              "Expected shape for Tensor %s is " % x.name, packed_shape,
733              " but saw shape: ", x_shape
734          ])
735
736    if not context.executing_eagerly() and sequence_length is not None:
737      # Perform some shape validation
738      with ops.control_dependencies(
739          [_assert_has_shape(sequence_length, [batch_size])]):
740        sequence_length = array_ops.identity(
741            sequence_length, name="CheckSeqLen")
742
743    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
744
745    (outputs, final_state) = _dynamic_rnn_loop(
746        cell,
747        inputs,
748        state,
749        parallel_iterations=parallel_iterations,
750        swap_memory=swap_memory,
751        sequence_length=sequence_length,
752        dtype=dtype)
753
754    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
755    # If we are performing batch-major calculations, transpose output back
756    # to shape [batch, time, depth]
757    if not time_major:
758      # (T,B,D) => (B,T,D)
759      outputs = nest.map_structure(_transpose_batch_time, outputs)
760
761    return (outputs, final_state)
762
763
764def _dynamic_rnn_loop(cell,
765                      inputs,
766                      initial_state,
767                      parallel_iterations,
768                      swap_memory,
769                      sequence_length=None,
770                      dtype=None):
771  """Internal implementation of Dynamic RNN.
772
773  Args:
774    cell: An instance of RNNCell.
775    inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
776      tuple of such elements.
777    initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
778      `cell.state_size` is a tuple, then this should be a tuple of tensors
779      having shapes `[batch_size, s] for s in cell.state_size`.
780    parallel_iterations: Positive Python int.
781    swap_memory: A Python boolean
782    sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
783    dtype: (optional) Expected dtype of output. If not specified, inferred from
784      initial_state.
785
786  Returns:
787    Tuple `(final_outputs, final_state)`.
788    final_outputs:
789      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
790      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
791      objects, then this returns a (possibly nested) tuple of Tensors matching
792      the corresponding shapes.
793    final_state:
794      A `Tensor`, or possibly nested tuple of Tensors, matching in length
795      and shapes to `initial_state`.
796
797  Raises:
798    ValueError: If the input depth cannot be inferred via shape inference
799      from the inputs.
800    ValueError: If time_step is not the same for all the elements in the
801      inputs.
802    ValueError: If batch_size is not the same for all the elements in the
803      inputs.
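
  For reference, a sketch of the internal call made by `dynamic_rnn` above,
  once the inputs have been transposed to time-major form:

  ```python
  outputs, final_state = _dynamic_rnn_loop(
      cell,
      inputs,                    # [time, batch, depth], or a nested tuple
      state,                     # initial state
      parallel_iterations=parallel_iterations,
      swap_memory=swap_memory,
      sequence_length=sequence_length,
      dtype=dtype)
  ```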
804  """
805  state = initial_state
806  assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
807
808  state_size = cell.state_size
809
810  flat_input = nest.flatten(inputs)
811  flat_output_size = nest.flatten(cell.output_size)
812
813  # Construct an initial output
814  input_shape = array_ops.shape(flat_input[0])
815  time_steps = input_shape[0]
816  batch_size = _best_effort_input_batch_size(flat_input)
817
818  inputs_got_shape = tuple(
819      input_.get_shape().with_rank_at_least(3) for input_ in flat_input)
820
821  const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]
822
823  for shape in inputs_got_shape:
824    if not shape[2:].is_fully_defined():
825      raise ValueError(
826          "Input size (depth of inputs) must be accessible via shape inference,"
827          " but saw value None.")
828    got_time_steps = shape.dims[0].value
829    got_batch_size = shape.dims[1].value
830    if const_time_steps != got_time_steps:
831      raise ValueError(
832          "Time steps is not the same for all the elements in the input in a "
833          "batch.")
834    if const_batch_size != got_batch_size:
835      raise ValueError(
836          "Batch_size is not the same for all the elements in the input.")
837
838  # Prepare dynamic conditional copying of state & output
839  def _create_zero_arrays(size):
840    size = _concat(batch_size, size)
841    return array_ops.zeros(
842        array_ops.stack(size), _infer_state_dtype(dtype, state))
843
844  flat_zero_output = tuple(
845      _create_zero_arrays(output) for output in flat_output_size)
846  zero_output = nest.pack_sequence_as(
847      structure=cell.output_size, flat_sequence=flat_zero_output)
848
849  if sequence_length is not None:
850    min_sequence_length = math_ops.reduce_min(sequence_length)
851    max_sequence_length = math_ops.reduce_max(sequence_length)
852  else:
853    max_sequence_length = time_steps
854
855  time = array_ops.constant(0, dtype=dtypes.int32, name="time")
856
857  with ops.name_scope("dynamic_rnn") as scope:
858    base_name = scope
859
860  def _create_ta(name, element_shape, dtype):
861    return tensor_array_ops.TensorArray(
862        dtype=dtype,
863        size=time_steps,
864        element_shape=element_shape,
865        tensor_array_name=base_name + name)
866
867  in_graph_mode = not context.executing_eagerly()
868  if in_graph_mode:
869    output_ta = tuple(
870        _create_ta(
871            "output_%d" % i,
872            element_shape=(
873                tensor_shape.TensorShape([const_batch_size]).concatenate(
874                    _maybe_tensor_shape_from_tensor(out_size))),
875            dtype=_infer_state_dtype(dtype, state))
876        for i, out_size in enumerate(flat_output_size))
877    input_ta = tuple(
878        _create_ta(
879            "input_%d" % i,
880            element_shape=flat_input_i.shape[1:],
881            dtype=flat_input_i.dtype)
882        for i, flat_input_i in enumerate(flat_input))
883    input_ta = tuple(
884        ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input))
885  else:
886    output_ta = tuple([0 for _ in range(time_steps.numpy())]
887                      for i in range(len(flat_output_size)))
888    input_ta = flat_input
889
890  def _time_step(time, output_ta_t, state):
891    """Take a time step of the dynamic RNN.
892
893    Args:
894      time: int32 scalar Tensor.
895      output_ta_t: List of `TensorArray`s that represent the output.
896      state: nested tuple of vector tensors that represent the state.
897
898    Returns:
899      The tuple (time + 1, output_ta_t with updated flow, new_state).
900    """
901
902    if in_graph_mode:
903      input_t = tuple(ta.read(time) for ta in input_ta)
904      # Restore some shape information
905      for input_, shape in zip(input_t, inputs_got_shape):
906        input_.set_shape(shape[1:])
907    else:
908      input_t = tuple(ta[time.numpy()] for ta in input_ta)
909
910    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
911    # Keras RNN cells only accept state as list, even if it's a single tensor.
912    call_cell = lambda: cell(input_t, state)
913
914    if sequence_length is not None:
915      (output, new_state) = _rnn_step(
916          time=time,
917          sequence_length=sequence_length,
918          min_sequence_length=min_sequence_length,
919          max_sequence_length=max_sequence_length,
920          zero_output=zero_output,
921          state=state,
922          call_cell=call_cell,
923          state_size=state_size,
924          skip_conditionals=True)
925    else:
926      (output, new_state) = call_cell()
927
928    # Pack state if using state tuples
929    output = nest.flatten(output)
930
931    if in_graph_mode:
932      output_ta_t = tuple(
933          ta.write(time, out) for ta, out in zip(output_ta_t, output))
934    else:
935      for ta, out in zip(output_ta_t, output):
936        ta[time.numpy()] = out
937
938    return (time + 1, output_ta_t, new_state)
939
940  if in_graph_mode:
941    # Make sure that we run at least 1 step, if necessary, to ensure
942    # the TensorArrays pick up the dynamic shape.
943    loop_bound = math_ops.minimum(time_steps,
944                                  math_ops.maximum(1, max_sequence_length))
945  else:
946    # Using max_sequence_length isn't currently supported in the Eager branch.
947    loop_bound = time_steps
948
949  _, output_final_ta, final_state = control_flow_ops.while_loop(
950      cond=lambda time, *_: time < loop_bound,
951      body=_time_step,
952      loop_vars=(time, output_ta, state),
953      parallel_iterations=parallel_iterations,
954      maximum_iterations=time_steps,
955      swap_memory=swap_memory)
956
957  # Unpack final output if not using output tuples.
958  if in_graph_mode:
959    final_outputs = tuple(ta.stack() for ta in output_final_ta)
960    # Restore some shape information
961    for output, output_size in zip(final_outputs, flat_output_size):
962      shape = _concat([const_time_steps, const_batch_size],
963                      output_size,
964                      static=True)
965      output.set_shape(shape)
966  else:
967    final_outputs = output_final_ta
968
969  final_outputs = nest.pack_sequence_as(
970      structure=cell.output_size, flat_sequence=final_outputs)
971  if not in_graph_mode:
972    final_outputs = nest.map_structure_up_to(
973        cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
974
975  return (final_outputs, final_state)
976
977
978@tf_export(v1=["nn.raw_rnn"])
979@dispatch.add_dispatch_support
980def raw_rnn(cell,
981            loop_fn,
982            parallel_iterations=None,
983            swap_memory=False,
984            scope=None):
985  """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.
986
987  **NOTE: This method is still in testing, and the API may change.**
988
989  This function is a more primitive version of `dynamic_rnn` that provides
990  more direct access to the inputs each iteration.  It also provides more
991  control over when to start and finish reading the sequence, and
992  what to emit for the output.
993
994  For example, it can be used to implement the dynamic decoder of a seq2seq
995  model.
996
997  Instead of working with `Tensor` objects, most operations work with
998  `TensorArray` objects directly.
999
1000  The operation of `raw_rnn`, in pseudo-code, is basically the following:
1001
1002  ```python
1003  time = tf.constant(0, dtype=tf.int32)
1004  (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
1005      time=time, cell_output=None, cell_state=None, loop_state=None)
1006  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
1007  state = initial_state
1008  while not all(finished):
1009    (output, cell_state) = cell(next_input, state)
1010    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
1011        time=time + 1, cell_output=output, cell_state=cell_state,
1012        loop_state=loop_state)
1013    # Emit zeros and copy forward state for minibatch entries that are finished.
1014    state = tf.where(finished, state, next_state)
1015    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
1016    emit_ta = emit_ta.write(time, emit)
1017    # If any new minibatch entries are marked as finished, mark these.
1018    finished = tf.logical_or(finished, next_finished)
1019    time += 1
1020  return (emit_ta, state, loop_state)
1021  ```
1022
1023  with the additional properties that output and state may be (possibly nested)
1024  tuples, as determined by `cell.output_size` and `cell.state_size`, and
1025  as a result the final `state` and `emit_ta` may themselves be tuples.
1026
1027  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:
1028
1029  ```python
1030  inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
1031                          dtype=tf.float32)
1032  sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
1033                                             dtype=tf.int32)
1034  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
1035  inputs_ta = inputs_ta.unstack(inputs)
1036
1037  cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
1038
1039  def loop_fn(time, cell_output, cell_state, loop_state):
1040    emit_output = cell_output  # == None for time == 0
1041    if cell_output is None:  # time == 0
1042      next_cell_state = cell.zero_state(batch_size, tf.float32)
1043    else:
1044      next_cell_state = cell_state
1045    elements_finished = (time >= sequence_length)
1046    finished = tf.reduce_all(elements_finished)
1047    next_input = tf.cond(
1048        finished,
1049        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
1050        lambda: inputs_ta.read(time))
1051    next_loop_state = None
1052    return (elements_finished, next_input, next_cell_state,
1053            emit_output, next_loop_state)
1054
1055  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
1056  outputs = outputs_ta.stack()
1057  ```
1058
1059  Args:
1060    cell: An instance of RNNCell.
1061    loop_fn: A callable that takes inputs `(time, cell_output, cell_state,
1062      loop_state)` and returns the tuple `(finished, next_input,
1063      next_cell_state, emit_output, next_loop_state)`. Here `time` is an int32
1064      scalar `Tensor`, `cell_output` is a `Tensor` or (possibly nested) tuple of
1065      tensors as determined by `cell.output_size`, and `cell_state` is a
1066      `Tensor` or (possibly nested) tuple of tensors, as determined by the
1067      `loop_fn` on its first call (and should match `cell.state_size`).
1068      The outputs are: `finished`, a boolean `Tensor` of
1069      shape `[batch_size]`, `next_input`: the next input to feed to `cell`,
1070      `next_cell_state`: the next state to feed to `cell`,
1071      and `emit_output`: the output to store for this iteration.  Note that
1072        `emit_output` should be a `Tensor` or (possibly nested) tuple of tensors
1073        which is aggregated in the `emit_ta` inside the `while_loop`. For the
1074        first call to `loop_fn`, the `emit_output` corresponds to the
1075        `emit_structure` which is then used to determine the size of the
1076        `zero_tensor` for the `emit_ta` (defaults to `cell.output_size`). For
1077        the subsequent calls to the `loop_fn`, the `emit_output` corresponds to
1078        the actual output tensor that is to be aggregated in the `emit_ta`. The
1079        parameter `cell_state` and output `next_cell_state` may be either a
1080        single or (possibly nested) tuple of tensors.  The parameter
1081        `loop_state` and output `next_loop_state` may be either a single or
1082        (possibly nested) tuple of `Tensor` and `TensorArray` objects.  This
1083        last parameter may be ignored by `loop_fn` and the return value may be
1084        `None`.  If it is not `None`, then the `loop_state` will be propagated
1085        through the RNN loop, for use purely by `loop_fn` to keep track of its
1086        own state. The `next_loop_state` parameter returned may be `None`.  The
1087        first call to `loop_fn` will be `time = 0`, `cell_output = None`,
1088      `cell_state = None`, and `loop_state = None`.  For this call: The
1089        `next_cell_state` value should be the value with which to initialize the
1090        cell's state.  It may be a final state from a previous RNN or it may be
1091        the output of `cell.zero_state()`.  It should be a (possibly nested)
1092        tuple structure of tensors. If `cell.state_size` is an integer, this
1093        must be a `Tensor` of appropriate type and shape `[batch_size,
1094        cell.state_size]`. If `cell.state_size` is a `TensorShape`, this must be
1095        a `Tensor` of appropriate type and shape `[batch_size] +
1096        cell.state_size`. If `cell.state_size` is a (possibly nested) tuple of
1097        ints or `TensorShape`, this will be a tuple having the corresponding
1098        shapes. The `emit_output` value may be either `None` or a (possibly
1099        nested) tuple structure of tensors, e.g., `(tf.zeros(shape_0,
1100        dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. If this first
1101        `emit_output` return value is `None`, then the `emit_ta` result of
1102        `raw_rnn` will have the same structure and dtypes as `cell.output_size`.
1103        Otherwise `emit_ta` will have the same structure, shapes (prepended with
1104        a `batch_size` dimension), and dtypes as `emit_output`.  The actual
1105        values returned for `emit_output` at this initializing call are ignored.
1106        Note, this emit structure must be consistent across all time steps.
1107    parallel_iterations: (Default: 32).  The number of iterations to run in
1108      parallel.  Those operations which do not have any temporal dependency and
1109      can be run in parallel, will be.  This parameter trades off time for
1110      space.  Values >> 1 use more memory but take less time, while smaller
1111      values use less memory but computations take longer.
1112    swap_memory: Transparently swap the tensors produced in forward inference
1113      but needed for back prop from GPU to CPU.  This allows training RNNs which
1114      would typically not fit on a single GPU, with very minimal (or no)
1115      performance penalty.
1116    scope: VariableScope for the created subgraph; defaults to "rnn".
1117
1118  Returns:
1119    A tuple `(emit_ta, final_state, final_loop_state)` where:
1120
1121    `emit_ta`: The RNN output `TensorArray`.
1122       If `loop_fn` returns a (possibly nested) set of Tensors for
1123       `emit_output` during initialization, (inputs `time = 0`,
1124       `cell_output = None`, and `loop_state = None`), then `emit_ta` will
1125       have the same structure, dtypes, and shapes as `emit_output` instead.
1126       If `loop_fn` returns `emit_output = None` during this call,
1127       the structure of `cell.output_size` is used:
1128       If `cell.output_size` is a (possibly nested) tuple of integers
1129       or `TensorShape` objects, then `emit_ta` will be a tuple having the
1130       same structure as `cell.output_size`, containing TensorArrays whose
1131       elements' shapes correspond to the shape data in `cell.output_size`.
1132
1133    `final_state`: The final cell state.  If `cell.state_size` is an int, this
1134      will be shaped `[batch_size, cell.state_size]`.  If it is a
1135      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
1136      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
1137      be a tuple having the corresponding shapes.
1138
1139    `final_loop_state`: The final loop state as returned by `loop_fn`.
1140
1141  Raises:
1142    TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
1143      a `callable`.
1144  """
1145  rnn_cell_impl.assert_like_rnncell("cell", cell)
1146
1147  if not callable(loop_fn):
1148    raise TypeError("loop_fn must be a callable")
1149
1150  parallel_iterations = parallel_iterations or 32
1151
1152  # Create a new scope in which the caching device is either
1153  # determined by the parent scope, or is set to place the cached
1154  # Variable using the same placement as for the rest of the RNN.
1155  with vs.variable_scope(scope or "rnn") as varscope:
1156    if _should_cache():
1157      if varscope.caching_device is None:
1158        varscope.set_caching_device(lambda op: op.device)
1159
1160    time = constant_op.constant(0, dtype=dtypes.int32)
1161    (elements_finished, next_input,
1162     initial_state, emit_structure, init_loop_state) = loop_fn(
1163         time, None, None, None)  # time, cell_output, cell_state, loop_state
1164    flat_input = nest.flatten(next_input)
1165
1166    # Need a surrogate loop state for the while_loop if none is available.
1167    loop_state = (
1168        init_loop_state if init_loop_state is not None else
1169        constant_op.constant(0, dtype=dtypes.int32))
1170
1171    input_shape = [input_.get_shape() for input_ in flat_input]
1172    static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0)
1173
1174    for input_shape_i in input_shape:
1175      # Static verification that batch sizes all match
1176      static_batch_size.assert_is_compatible_with(
1177          tensor_shape.dimension_at_index(input_shape_i, 0))
1178
1179    batch_size = tensor_shape.dimension_value(static_batch_size)
1180    const_batch_size = batch_size
1181    if batch_size is None:
1182      batch_size = array_ops.shape(flat_input[0])[0]
1183
1184    nest.assert_same_structure(initial_state, cell.state_size)
1185    state = initial_state
1186    flat_state = nest.flatten(state)
1187    flat_state = [ops.convert_to_tensor(s) for s in flat_state]
1188    state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)
1189
1190    if emit_structure is not None:
1191      flat_emit_structure = nest.flatten(emit_structure)
1192      flat_emit_size = [
1193          emit.shape if emit.shape.is_fully_defined() else array_ops.shape(emit)
1194          for emit in flat_emit_structure
1195      ]
1196      flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
1197    else:
1198      emit_structure = cell.output_size
1199      flat_emit_size = nest.flatten(emit_structure)
1200      flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
1201
1202    flat_emit_ta = [
1203        tensor_array_ops.TensorArray(
1204            dtype=dtype_i,
1205            dynamic_size=True,
1206            element_shape=(tensor_shape.TensorShape([
1207                const_batch_size
1208            ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
1209            size=0,
1210            name="rnn_output_%d" % i)
1211        for i, (dtype_i,
1212                size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
1213    ]
1214    emit_ta = nest.pack_sequence_as(
1215        structure=emit_structure, flat_sequence=flat_emit_ta)
1216    flat_zero_emit = [
1217        array_ops.zeros(_concat(batch_size, size_i), dtype_i)
1218        for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
1219    ]
1220    zero_emit = nest.pack_sequence_as(
1221        structure=emit_structure, flat_sequence=flat_zero_emit)
1222
1223    def condition(unused_time, elements_finished, *_):
1224      return math_ops.logical_not(math_ops.reduce_all(elements_finished))
1225
1226    def body(time, elements_finished, current_input, emit_ta, state,
1227             loop_state):
1228      """Internal while loop body for raw_rnn.
1229
1230      Args:
1231        time: time scalar.
1232        elements_finished: batch-size vector.
1233        current_input: possibly nested tuple of input tensors.
1234        emit_ta: possibly nested tuple of output TensorArrays.
1235        state: possibly nested tuple of state tensors.
1236        loop_state: possibly nested tuple of loop state tensors.
1237
1238      Returns:
1239        Tuple having the same size as Args but with updated values.
1240      """
1241      (next_output, cell_state) = cell(current_input, state)
1242
1243      nest.assert_same_structure(state, cell_state)
1244      nest.assert_same_structure(cell.output_size, next_output)
1245
1246      next_time = time + 1
1247      (next_finished, next_input, next_state, emit_output,
1248       next_loop_state) = loop_fn(next_time, next_output, cell_state,
1249                                  loop_state)
1250
1251      nest.assert_same_structure(state, next_state)
1252      nest.assert_same_structure(current_input, next_input)
1253      nest.assert_same_structure(emit_ta, emit_output)
1254
1255      # If loop_fn returns None for next_loop_state, just reuse the
1256      # previous one.
1257      loop_state = loop_state if next_loop_state is None else next_loop_state
1258
1259      def _copy_some_through(current, candidate):
1260        """Copy some tensors through via array_ops.where."""
1261
1262        def copy_fn(cur_i, cand_i):
1263          # TensorArray and scalar get passed through.
1264          if isinstance(cur_i, tensor_array_ops.TensorArray):
1265            return cand_i
1266          if cur_i.shape.rank == 0:
1267            return cand_i
1268          # Otherwise propagate the old or the new value.
1269          with ops.colocate_with(cand_i):
1270            return array_ops.where(elements_finished, cur_i, cand_i)
1271
1272        return nest.map_structure(copy_fn, current, candidate)
1273
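      # For batch entries that have already finished, emit zeros and carry the
      # previous state forward instead of the freshly computed values.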
1274      emit_output = _copy_some_through(zero_emit, emit_output)
1275      next_state = _copy_some_through(state, next_state)
1276
1277      emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
1278                                   emit_ta, emit_output)
1279
1280      elements_finished = math_ops.logical_or(elements_finished, next_finished)
1281
1282      return (next_time, elements_finished, next_input, emit_ta, next_state,
1283              loop_state)
1284
1285    returned = control_flow_ops.while_loop(
1286        condition,
1287        body,
1288        loop_vars=[
1289            time, elements_finished, next_input, emit_ta, state, loop_state
1290        ],
1291        parallel_iterations=parallel_iterations,
1292        swap_memory=swap_memory)
1293
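    # `while_loop` returns the final values of `loop_vars`; the last three are
    # the emit TensorArrays, the final cell state, and the user-managed
    # loop state.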
1294    (emit_ta, final_state, final_loop_state) = returned[-3:]
1295
1296    if init_loop_state is None:
1297      final_loop_state = None
1298
1299    return (emit_ta, final_state, final_loop_state)
1300
1301
1302@deprecation.deprecated(None,
1303                        "Please use `keras.layers.RNN(cell, unroll=True)`, "
1304                        "which is equivalent to this API")
1305@tf_export(v1=["nn.static_rnn"])
1306@dispatch.add_dispatch_support
1307def static_rnn(cell,
1308               inputs,
1309               initial_state=None,
1310               dtype=None,
1311               sequence_length=None,
1312               scope=None):
1313  """Creates a recurrent neural network specified by RNNCell `cell`.
1314
1315  The simplest form of RNN network generated is:
1316
1317  ```python
1318    state = cell.zero_state(...)
1319    outputs = []
1320    for input_ in inputs:
1321      output, state = cell(input_, state)
1322      outputs.append(output)
1323    return (outputs, state)
1324  ```
1325  However, a few other options are available:
1326
1327  An initial state can be provided.
1328  If the sequence_length vector is provided, dynamic calculation is performed.
1329  This method of calculation does not compute the RNN steps past the maximum
1330  sequence length of the minibatch (thus saving computational time),
1331  and properly propagates the state at an example's sequence length
1332  to the final state output.
1333
1334  The dynamic calculation performed is, at time `t` for batch row `b`,
1335
1336  ```python
1337    (output, state)(b, t) =
1338      (t >= sequence_length(b))
1339        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
1340        : cell(input(b, t), state(b, t - 1))
1341  ```
1342
1343  Args:
1344    cell: An instance of RNNCell.
1345    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
1346      input_size]`, or a nested tuple of such elements.
1347    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
1348      is an integer, this must be a `Tensor` of appropriate type and shape
1349      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
1350      should be a tuple of tensors having shapes `[batch_size, s] for s in
1351      cell.state_size`.
1352    dtype: (optional) The data type for the initial state and expected output.
1353      Required if initial_state is not provided or RNN state has a heterogeneous
1354      dtype.
1355    sequence_length: Specifies the length of each sequence in inputs. An int32
1356      or int64 vector (tensor) of size `[batch_size]`, with values in `[0, T)`.
1357    scope: VariableScope for the created subgraph; defaults to "rnn".
1358
1359  Returns:
1360    A pair (outputs, state) where:
1361
1362    - outputs is a length T list of outputs (one for each input), or a nested
1363      tuple of such elements.
1364    - state is the final state.
1365
1366  Raises:
1367    TypeError: If `cell` is not an instance of RNNCell.
1368    ValueError: If `inputs` is `None` or an empty list, or if the input depth
1369      (column size) cannot be inferred from inputs via shape inference.
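
  Example (a minimal sketch of typical TF1 graph-mode usage; the cell type,
  shapes, and names below are illustrative, not prescriptive):

  ```python
    batch_size, input_size, num_units, T = 32, 128, 256, 10
    cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units)
    # A length-T list of [batch_size, input_size] input tensors.
    inputs = [
        tf.compat.v1.placeholder(tf.float32, [batch_size, input_size])
        for _ in range(T)
    ]
    outputs, state = tf.compat.v1.nn.static_rnn(cell, inputs, dtype=tf.float32)
    # outputs: length-T list of [batch_size, num_units] tensors.
    # state: the final LSTMStateTuple for the batch.
  ```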
1370  """
1371  rnn_cell_impl.assert_like_rnncell("cell", cell)
1372  if not nest.is_sequence(inputs):
1373    raise TypeError("inputs must be a sequence")
1374  if not inputs:
1375    raise ValueError("inputs must not be empty")
1376
1377  outputs = []
1378  # Create a new scope in which the caching device is either
1379  # determined by the parent scope, or is set to place the cached
1380  # Variable using the same placement as for the rest of the RNN.
1381  with vs.variable_scope(scope or "rnn") as varscope:
1382    if _should_cache():
1383      if varscope.caching_device is None:
1384        varscope.set_caching_device(lambda op: op.device)
1385
1386    # Obtain the first sequence of the input
1387    first_input = inputs
1388    while nest.is_sequence(first_input):
1389      first_input = first_input[0]
1390
1391    # Temporarily avoid EmbeddingWrapper and seq2seq badness
1392    # TODO(lukaszkaiser): remove EmbeddingWrapper
1393    if first_input.get_shape().rank != 1:
1394
1395      input_shape = first_input.get_shape().with_rank_at_least(2)
1396      fixed_batch_size = input_shape.dims[0]
1397
1398      flat_inputs = nest.flatten(inputs)
1399      for flat_input in flat_inputs:
1400        input_shape = flat_input.get_shape().with_rank_at_least(2)
1401        batch_size, input_size = tensor_shape.dimension_at_index(
1402            input_shape, 0), input_shape[1:]
1403        fixed_batch_size.assert_is_compatible_with(batch_size)
1404        for i, size in enumerate(input_size.dims):
1405          if tensor_shape.dimension_value(size) is None:
1406            raise ValueError(
1407                "Input size (dimension %d of inputs) must be accessible via "
1408                "shape inference, but saw value None." % i)
1409    else:
1410      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]
1411
1412    if tensor_shape.dimension_value(fixed_batch_size):
1413      batch_size = tensor_shape.dimension_value(fixed_batch_size)
1414    else:
1415      batch_size = array_ops.shape(first_input)[0]
1416    if initial_state is not None:
1417      state = initial_state
1418    else:
1419      if not dtype:
1420        raise ValueError("If no initial_state is provided, "
1421                         "dtype must be specified")
1422      if getattr(cell, "get_initial_state", None) is not None:
1423        state = cell.get_initial_state(
1424            inputs=None, batch_size=batch_size, dtype=dtype)
1425      else:
1426        state = cell.zero_state(batch_size, dtype)
1427
1428    if sequence_length is not None:  # Prepare variables
1429      sequence_length = ops.convert_to_tensor(
1430          sequence_length, name="sequence_length")
1431      if sequence_length.get_shape().rank not in (None, 1):
1432        raise ValueError(
1433            "sequence_length must be a vector of length batch_size")
1434
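      # zero_output stands in for the cell output of rows whose sequence has
      # already ended; _rnn_step emits it for steps past sequence_length.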
1435      def _create_zero_output(output_size):
1436        # convert int to TensorShape if necessary
1437        size = _concat(batch_size, output_size)
1438        output = array_ops.zeros(
1439            array_ops.stack(size), _infer_state_dtype(dtype, state))
1440        shape = _concat(
1441            tensor_shape.dimension_value(fixed_batch_size),
1442            output_size,
1443            static=True)
1444        output.set_shape(tensor_shape.TensorShape(shape))
1445        return output
1446
1447      output_size = cell.output_size
1448      flat_output_size = nest.flatten(output_size)
1449      flat_zero_output = tuple(
1450          _create_zero_output(size) for size in flat_output_size)
1451      zero_output = nest.pack_sequence_as(
1452          structure=output_size, flat_sequence=flat_zero_output)
1453
1454      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
1455      min_sequence_length = math_ops.reduce_min(sequence_length)
1456      max_sequence_length = math_ops.reduce_max(sequence_length)
1457
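    # Statically unroll the RNN: one cell call per time step, reusing the
    # cell's variables after the first step. When sequence_length is given,
    # _rnn_step skips work past max_sequence_length and copies state through
    # for rows that have already finished.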
1458    for time, input_ in enumerate(inputs):
1459      if time > 0:
1460        varscope.reuse_variables()
1461      # pylint: disable=cell-var-from-loop
1462      call_cell = lambda: cell(input_, state)
1463      # pylint: enable=cell-var-from-loop
1464      if sequence_length is not None:
1465        (output, state) = _rnn_step(
1466            time=time,
1467            sequence_length=sequence_length,
1468            min_sequence_length=min_sequence_length,
1469            max_sequence_length=max_sequence_length,
1470            zero_output=zero_output,
1471            state=state,
1472            call_cell=call_cell,
1473            state_size=cell.state_size)
1474      else:
1475        (output, state) = call_cell()
1476      outputs.append(output)
1477
1478    return (outputs, state)
1479
1480
1481@deprecation.deprecated(None,
1482                        "Please use `keras.layers.RNN(cell, stateful=True)`, "
1483                        "which is equivalent to this API")
1484@tf_export(v1=["nn.static_state_saving_rnn"])
1485@dispatch.add_dispatch_support
1486def static_state_saving_rnn(cell,
1487                            inputs,
1488                            state_saver,
1489                            state_name,
1490                            sequence_length=None,
1491                            scope=None):
1492  """RNN that accepts a state saver for time-truncated RNN calculation.
1493
1494  Args:
1495    cell: An instance of `RNNCell`.
1496    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
1497      input_size]`.
1498    state_saver: A state saver object with methods `state` and `save_state`.
1499    state_name: Python string or tuple of strings.  The name to use with the
1500      state_saver. If the cell returns tuples of states (i.e., `cell.state_size`
1501      is a tuple) then `state_name` should be a tuple of strings having the same
1502      length as `cell.state_size`.  Otherwise it should be a single string.
1503    sequence_length: (optional) An int32/int64 vector size [batch_size]. See the
1504      documentation for rnn() for more details about sequence_length.
1505    scope: VariableScope for the created subgraph; defaults to "rnn".
1506
1507  Returns:
1508    A pair (outputs, state) where:
1509      outputs is a length T list of outputs (one for each input).
1510      state is the final state.
1511
1512  Raises:
1513    TypeError: If `cell` is not an instance of RNNCell.
1514    ValueError: If `inputs` is `None` or an empty list, or if the arity and
1515      type of `state_name` do not match those of `cell.state_size`.
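
  Example (a minimal sketch; `_SimpleStateSaver` below is a hypothetical
  stand-in for a real state saver and is not part of TensorFlow):

  ```python
    class _SimpleStateSaver(object):
      # Keeps per-name RNN state in non-trainable variables between
      # truncated segments.

      def __init__(self, states):
        self._states = states  # dict mapping state_name -> tf.Variable

      def state(self, name):
        return self._states[name]

      def save_state(self, name, value):
        return self._states[name].assign(value)

    cell = tf.compat.v1.nn.rnn_cell.GRUCell(64)
    inputs = [tf.compat.v1.placeholder(tf.float32, [32, 16]) for _ in range(5)]
    saver = _SimpleStateSaver({
        "h": tf.compat.v1.get_variable(
            "h", [32, 64], initializer=tf.compat.v1.zeros_initializer(),
            trainable=False)
    })
    outputs, state = tf.compat.v1.nn.static_state_saving_rnn(
        cell, inputs, state_saver=saver, state_name="h")
  ```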
1516  """
1517  state_size = cell.state_size
1518  state_is_tuple = nest.is_sequence(state_size)
1519  state_name_tuple = nest.is_sequence(state_name)
1520
1521  if state_is_tuple != state_name_tuple:
1522    raise ValueError("state_name should be the same type as cell.state_size.  "
1523                     "state_name: %s, cell.state_size: %s" %
1524                     (str(state_name), str(state_size)))
1525
1526  if state_is_tuple:
1527    state_name_flat = nest.flatten(state_name)
1528    state_size_flat = nest.flatten(state_size)
1529
1530    if len(state_name_flat) != len(state_size_flat):
1531      raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" %
1532                       (len(state_name_flat), len(state_size_flat)))
1533
1534    initial_state = nest.pack_sequence_as(
1535        structure=state_size,
1536        flat_sequence=[state_saver.state(s) for s in state_name_flat])
1537  else:
1538    initial_state = state_saver.state(state_name)
1539
1540  (outputs, state) = static_rnn(
1541      cell,
1542      inputs,
1543      initial_state=initial_state,
1544      sequence_length=sequence_length,
1545      scope=scope)
1546
1547  if state_is_tuple:
1548    flat_state = nest.flatten(state)
1549    state_name = nest.flatten(state_name)
1550    save_state = [
1551        state_saver.save_state(name, substate)
1552        for name, substate in zip(state_name, flat_state)
1553    ]
1554  else:
1555    save_state = [state_saver.save_state(state_name, state)]
1556
1557  with ops.control_dependencies(save_state):
1558    last_output = outputs[-1]
1559    flat_last_output = nest.flatten(last_output)
1560    flat_last_output = [
1561        array_ops.identity(output) for output in flat_last_output
1562    ]
1563    outputs[-1] = nest.pack_sequence_as(
1564        structure=last_output, flat_sequence=flat_last_output)
1565
1566    if state_is_tuple:
1567      state = nest.pack_sequence_as(
1568          structure=state,
1569          flat_sequence=[array_ops.identity(s) for s in flat_state])
1570    else:
1571      state = array_ops.identity(state)
1572
1573  return (outputs, state)
1574
1575
1576@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
1577                        "keras.layers.RNN(cell, unroll=True))`, which is "
1578                        "equivalent to this API")
1579@tf_export(v1=["nn.static_bidirectional_rnn"])
1580@dispatch.add_dispatch_support
1581def static_bidirectional_rnn(cell_fw,
1582                             cell_bw,
1583                             inputs,
1584                             initial_state_fw=None,
1585                             initial_state_bw=None,
1586                             dtype=None,
1587                             sequence_length=None,
1588                             scope=None):
1589  """Creates a bidirectional recurrent neural network.
1590
1591  Similar to the unidirectional case above (`static_rnn`), but builds independent
1592  forward and backward RNNs whose final outputs are depth-concatenated, so the
1593  output has the format
1594  [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of the
1595  forward and backward cells must match. The initial state for both directions
1596  is zero by default (but can be set optionally), and no intermediate states are
1597  ever returned -- the network is fully unrolled up to the given sequence
1598  length(s), or completely unrolled if no lengths are given.
1599
1600  Args:
1601    cell_fw: An instance of RNNCell, to be used for forward direction.
1602    cell_bw: An instance of RNNCell, to be used for backward direction.
1603    inputs: A length T list of inputs, each a tensor of shape [batch_size,
1604      input_size], or a nested tuple of such elements.
1605    initial_state_fw: (optional) An initial state for the forward RNN. This must
1606      be a tensor of appropriate type and shape `[batch_size,
1607      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
1608      tuple of tensors having shapes `[batch_size, s] for s in
1609      cell_fw.state_size`.
1610    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
1611      corresponding properties of `cell_bw`.
1612    dtype: (optional) The data type for the initial state.  Required if either
1613      of the initial states are not provided.
1614    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
1615      containing the actual lengths for each of the sequences.
1616    scope: VariableScope for the created subgraph; defaults to
1617      "bidirectional_rnn".
1618
1619  Returns:
1620    A tuple (outputs, output_state_fw, output_state_bw) where:
1621      outputs is a length `T` list of outputs (one for each input), which
1622        are depth-concatenated forward and backward outputs.
1623      output_state_fw is the final state of the forward rnn.
1624      output_state_bw is the final state of the backward rnn.
1625
1626  Raises:
1627    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
1628    ValueError: If inputs is None or an empty list.
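
  Example (a minimal sketch; the cells, shapes, and names are illustrative):

  ```python
    cell_fw = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(128)
    cell_bw = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(128)
    # A length-T list of [batch_size, input_size] input tensors.
    inputs = [
        tf.compat.v1.placeholder(tf.float32, [32, 64]) for _ in range(10)
    ]
    outputs, state_fw, state_bw = tf.compat.v1.nn.static_bidirectional_rnn(
        cell_fw, cell_bw, inputs, dtype=tf.float32)
    # Each element of `outputs` has shape [32, 256]: forward and backward
    # outputs concatenated along the last axis.
  ```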
1629  """
1630  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
1631  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
1632  if not nest.is_sequence(inputs):
1633    raise TypeError("inputs must be a sequence")
1634  if not inputs:
1635    raise ValueError("inputs must not be empty")
1636
1637  with vs.variable_scope(scope or "bidirectional_rnn"):
1638    # Forward direction
1639    with vs.variable_scope("fw") as fw_scope:
1640      output_fw, output_state_fw = static_rnn(
1641          cell_fw,
1642          inputs,
1643          initial_state_fw,
1644          dtype,
1645          sequence_length,
1646          scope=fw_scope)
1647
1648    # Backward direction
1649    with vs.variable_scope("bw") as bw_scope:
1650      reversed_inputs = _reverse_seq(inputs, sequence_length)
1651      tmp, output_state_bw = static_rnn(
1652          cell_bw,
1653          reversed_inputs,
1654          initial_state_bw,
1655          dtype,
1656          sequence_length,
1657          scope=bw_scope)
1658
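  # The backward pass produced outputs in reversed time order; flip them back
  # so they align element-wise with the forward outputs before concatenation.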
1659  output_bw = _reverse_seq(tmp, sequence_length)
1660  # Concat each of the forward/backward outputs
1661  flat_output_fw = nest.flatten(output_fw)
1662  flat_output_bw = nest.flatten(output_bw)
1663
1664  flat_outputs = tuple(
1665      array_ops.concat([fw, bw], 1)
1666      for fw, bw in zip(flat_output_fw, flat_output_bw))
1667
1668  outputs = nest.pack_sequence_as(
1669      structure=output_fw, flat_sequence=flat_outputs)
1670
1671  return (outputs, output_state_fw, output_state_bw)
1672