1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""RNN helpers for TensorFlow models."""
16from tensorflow.python.eager import context
17from tensorflow.python.framework import constant_op
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.framework import tensor_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import control_flow_util
25from tensorflow.python.ops import control_flow_util_v2
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import rnn_cell_impl
28from tensorflow.python.ops import tensor_array_ops
29from tensorflow.python.ops import variable_scope as vs
30from tensorflow.python.util import deprecation
31from tensorflow.python.util import dispatch
32from tensorflow.python.util import nest
33from tensorflow.python.util.tf_export import tf_export
34
35# pylint: disable=protected-access
36_concat = rnn_cell_impl._concat
37# pylint: enable=protected-access
38
39
40def _transpose_batch_time(x):
41  """Transposes the batch and time dimensions of a Tensor.
42
43  If the input tensor has rank < 2, it is returned unchanged. Retains as
44  much of the static shape information as possible.
45
46  Args:
47    x: A Tensor.
48
49  Returns:
50    x transposed along the first two dimensions.
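
  For example, a minimal sketch (the concrete shapes are illustrative):

  ```python
  x = tf.zeros([32, 10, 8])         # [batch, time, depth]
  x_t = _transpose_batch_time(x)    # [time, batch, depth], i.e. [10, 32, 8]
  ```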
51  """
52  x_static_shape = x.get_shape()
53  if x_static_shape.rank is not None and x_static_shape.rank < 2:
54    return x
55
56  x_rank = array_ops.rank(x)
57  x_t = array_ops.transpose(
58      x, array_ops.concat(([1, 0], math_ops.range(2, x_rank)), axis=0))
59  x_t.set_shape(
60      tensor_shape.TensorShape(
61          [x_static_shape.dims[1].value,
62           x_static_shape.dims[0].value]).concatenate(x_static_shape[2:]))
63  return x_t
64
65
66def _best_effort_input_batch_size(flat_input):
67  """Get static input batch size if available, with fallback to the dynamic one.
68
69  Args:
70    flat_input: An iterable of time major input Tensors of shape `[max_time,
71      batch_size, ...]`. All inputs should have compatible batch sizes.
72
73  Returns:
74    The batch size in Python integer if available, or a scalar Tensor otherwise.
75
76  Raises:
77    ValueError: if there is any input with an invalid shape.
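
  For example, a minimal sketch (assuming graph mode so the batch dimension
  can be left undefined; the shapes are illustrative):

  ```python
  x = tf.compat.v1.placeholder(tf.float32, [None, 32, 128])  # time major
  _best_effort_input_batch_size([x])   # -> 32, a Python int
  ```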
78  """
79  for input_ in flat_input:
80    shape = input_.shape
81    if shape.rank is None:
82      continue
83    if shape.rank < 2:
84      raise ValueError("Input tensor should have rank >= 2. Received input="
85                       f"{input_} of rank {shape.rank}")
86    batch_size = shape.dims[1].value
87    if batch_size is not None:
88      return batch_size
89  # Fallback to the dynamic batch size of the first input.
90  return array_ops.shape(flat_input[0])[1]
91
92
93def _infer_state_dtype(explicit_dtype, state):
94  """Infer the dtype of an RNN state.
95
96  Args:
97    explicit_dtype: explicitly declared dtype or None.
98    state: RNN's hidden state. Must be a Tensor or a nested iterable containing
99      Tensors.
100
101  Returns:
102    dtype: inferred dtype of hidden state.
103
104  Raises:
105    ValueError: if `state` has heterogeneous dtypes or is empty.
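
  For example, a minimal sketch (the state structure is illustrative):

  ```python
  state = (tf.zeros([4, 8]), tf.zeros([4, 8]))   # e.g. an LSTM (c, h) pair
  _infer_state_dtype(None, state)        # -> tf.float32, inferred from state
  _infer_state_dtype(tf.float64, state)  # -> tf.float64, explicit dtype wins
  ```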
106  """
107  if explicit_dtype is not None:
108    return explicit_dtype
109  elif nest.is_nested(state):
110    inferred_dtypes = [element.dtype for element in nest.flatten(state)]
111    if not inferred_dtypes:
112      raise ValueError(f"Unable to infer dtype from argument state={state}.")
113    all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes)
114    if not all_same:
115      raise ValueError(
116          f"Argument state={state} has tensors of different inferred dtypes. "
117          "Unable to infer a single representative dtype. Dtypes received: "
118          f"{inferred_dtypes}")
119    return inferred_dtypes[0]
120  else:
121    return state.dtype
122
123
124def _maybe_tensor_shape_from_tensor(shape):
125  if isinstance(shape, ops.Tensor):
126    return tensor_shape.as_shape(tensor_util.constant_value(shape))
127  else:
128    return shape
129
130
131def _should_cache():
132  """Returns True if a default caching device should be set, otherwise False."""
133  if context.executing_eagerly():
134    return False
135  # Don't set a caching device when running in a loop, since it is possible that
136  # train steps could be wrapped in a tf.while_loop. In that scenario caching
137  # prevents forward computations in loop iterations from re-reading the
138  # updated weights.
139  graph = ops.get_default_graph()
140  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
141  in_v1_while_loop = (
142      control_flow_util.GetContainingWhileContext(ctxt) is not None)
143  in_v2_while_loop = control_flow_util_v2.in_while_loop_defun(graph)
144  return not in_v1_while_loop and not in_v2_while_loop
145
146
147# pylint: disable=unused-argument
148def _rnn_step(time,
149              sequence_length,
150              min_sequence_length,
151              max_sequence_length,
152              zero_output,
153              state,
154              call_cell,
155              state_size,
156              skip_conditionals=False):
157  """Calculate one step of a dynamic RNN minibatch.
158
159  Returns an (output, state) pair conditioned on `sequence_length`.
160  When skip_conditionals=False, the pseudocode is something like:
161
162  if t >= max_sequence_length:
163    return (zero_output, state)
164  if t < min_sequence_length:
165    return call_cell()
166
167  # Selectively output zeros or output, old state or new state depending
168  # on whether we've finished calculating each row.
169  new_output, new_state = call_cell()
170  final_output = np.vstack([
171    zero_output if time >= sequence_length[r] else new_output_r
172    for r, new_output_r in enumerate(new_output)
173  ])
174  final_state = np.vstack([
175    state[r] if time >= sequence_length[r] else new_state_r
176    for r, new_state_r in enumerate(new_state)
177  ])
178  return (final_output, final_state)
179
180  Args:
181    time: int32 `Tensor` scalar.
182    sequence_length: int32 `Tensor` vector of size [batch_size].
183    min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
184    max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
185    zero_output: `Tensor` vector of shape [output_size].
186    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
187      or a list/tuple of such tensors.
188    call_cell: lambda returning a tuple (new_output, new_state), where
189      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`
190      and new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
191    state_size: The `cell.state_size` associated with the state.
192    skip_conditionals: Python bool, whether to skip using the conditional
193      calculations.  This is useful for `dynamic_rnn`, where the input tensor
194      matches `max_sequence_length`, and using conditionals just slows
195      everything down.
196
197  Returns:
198    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
199      final_output is a `Tensor` matrix of shape [batch_size, output_size]
200      final_state is either a single `Tensor` matrix, or a tuple of such
201        matrices (matching length and shapes of input `state`).
202
203  Raises:
204    ValueError: If the cell returns a state tuple whose length does not match
205      that returned by `state_size`.
206  """
207
208  # Convert state to a list for ease of use
209  flat_state = nest.flatten(state)
210  flat_zero_output = nest.flatten(zero_output)
211
212  # Vector describing which batch entries are finished.
213  copy_cond = time >= sequence_length
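  # For example (illustrative values): with time = 3 and
  # sequence_length = [5, 2], copy_cond is [False, True], so batch entry 1
  # keeps its previous state and emits the zero output, while entry 0 takes
  # the freshly computed values below.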
214
215  def _copy_one_through(output, new_output):
216    # TensorArray and scalar get passed through.
217    if isinstance(output, tensor_array_ops.TensorArray):
218      return new_output
219    if output.shape.rank == 0:
220      return new_output
221    # Otherwise propagate the old or the new value.
222    with ops.colocate_with(new_output):
223      return array_ops.where(copy_cond, output, new_output)
224
225  def _copy_some_through(flat_new_output, flat_new_state):
226    # Use broadcasting select to determine which values should get
227    # the previous state & zero output, and which values should get
228    # a calculated state & output.
229    flat_new_output = [
230        _copy_one_through(zero_output, new_output)
231        for zero_output, new_output in zip(flat_zero_output, flat_new_output)
232    ]
233    flat_new_state = [
234        _copy_one_through(state, new_state)
235        for state, new_state in zip(flat_state, flat_new_state)
236    ]
237    return flat_new_output + flat_new_state
238
239  def _maybe_copy_some_through():
240    """Run RNN step.  Pass through either no or some past state."""
241    new_output, new_state = call_cell()
242
243    nest.assert_same_structure(zero_output, new_output)
244    nest.assert_same_structure(state, new_state)
245
246    flat_new_state = nest.flatten(new_state)
247    flat_new_output = nest.flatten(new_output)
248    return control_flow_ops.cond(
249        # if t < min_seq_len: calculate and return everything
250        time < min_sequence_length,
251        lambda: flat_new_output + flat_new_state,
252        # else copy some of it through
253        lambda: _copy_some_through(flat_new_output, flat_new_state))
254
255  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
256  # but benefits from removing cond() and its gradient.  We should
257  # profile with and without this switch here.
258  if skip_conditionals:
259    # Instead of using conditionals, perform the selective copy at all time
260    # steps.  This is faster when max_seq_len is equal to the number of unrolls
261    # (which is typical for dynamic_rnn).
262    new_output, new_state = call_cell()
263    nest.assert_same_structure(zero_output, new_output)
264    nest.assert_same_structure(state, new_state)
265    new_state = nest.flatten(new_state)
266    new_output = nest.flatten(new_output)
267    final_output_and_state = _copy_some_through(new_output, new_state)
268  else:
269    empty_update = lambda: flat_zero_output + flat_state
270    final_output_and_state = control_flow_ops.cond(
271        # if t >= max_seq_len: copy all state through, output zeros
272        time >= max_sequence_length,
273        empty_update,
274        # otherwise calculation is required: copy some or all of it through
275        _maybe_copy_some_through)
276
277  if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
278    raise ValueError("Internal error: state and output were not concatenated "
279                     f"correctly. Received state length: {len(flat_state)}, "
280                     f"output length: {len(flat_zero_output)}. Concatenated "
281                     f"length received: {len(final_output_and_state)}.")
282  final_output = final_output_and_state[:len(flat_zero_output)]
283  final_state = final_output_and_state[len(flat_zero_output):]
284
285  for output, flat_output in zip(final_output, flat_zero_output):
286    output.set_shape(flat_output.get_shape())
287  for substate, flat_substate in zip(final_state, flat_state):
288    if not isinstance(substate, tensor_array_ops.TensorArray):
289      substate.set_shape(flat_substate.get_shape())
290
291  final_output = nest.pack_sequence_as(
292      structure=zero_output, flat_sequence=final_output)
293  final_state = nest.pack_sequence_as(
294      structure=state, flat_sequence=final_state)
295
296  return final_output, final_state
297
298
299def _reverse_seq(input_seq, lengths):
300  """Reverse a list of Tensors up to specified lengths.
301
302  Args:
303    input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
304      or nested tuples of tensors.
305    lengths:   A `Tensor` of dimension batch_size, containing lengths for each
306      sequence in the batch. If `None`, the list is simply reversed.
307
308  Returns:
309    time-reversed sequence
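
  For example (values and lengths are illustrative): with `lengths = [3, 2]`,
  the per-batch series `[a, b, c, PAD]` and `[d, e, PAD, PAD]` become
  `[c, b, a, PAD]` and `[e, d, PAD, PAD]`; only the first `lengths[i]` steps
  of each sequence are reversed, and the padding positions stay in place.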
310  """
311  if lengths is None:
312    return list(reversed(input_seq))
313
314  flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)
315
316  flat_results = [[] for _ in range(len(input_seq))]
317  for sequence in zip(*flat_input_seq):
318    input_shape = tensor_shape.unknown_shape(rank=sequence[0].get_shape().rank)
319    for input_ in sequence:
320      input_shape.assert_is_compatible_with(input_.get_shape())
321      input_.set_shape(input_shape)
322
323    # Join into (time, batch_size, depth)
324    s_joined = array_ops.stack(sequence)
325
326    # Reverse along dimension 0
327    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
328    # Split again into list
329    result = array_ops.unstack(s_reversed)
330    for r, flat_result in zip(result, flat_results):
331      r.set_shape(input_shape)
332      flat_result.append(r)
333
334  results = [
335      nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
336      for input_, flat_result in zip(input_seq, flat_results)
337  ]
338  return results
339
340
341@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
342                        "keras.layers.RNN(cell))`, which is equivalent to "
343                        "this API")
344@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
345@dispatch.add_dispatch_support
346def bidirectional_dynamic_rnn(cell_fw,
347                              cell_bw,
348                              inputs,
349                              sequence_length=None,
350                              initial_state_fw=None,
351                              initial_state_bw=None,
352                              dtype=None,
353                              parallel_iterations=None,
354                              swap_memory=False,
355                              time_major=False,
356                              scope=None):
357  """Creates a dynamic version of bidirectional recurrent neural network.
358
359  Takes input and builds independent forward and backward RNNs. The input_size
360  of the forward and backward cells must match. The initial state for both
361  directions defaults to zero (but can be set optionally), and no intermediate
362  states are ever returned -- the network is unrolled for the given (passed-in)
363  length(s) of the sequence(s), or completely unrolled if lengths are not
364  given.
365
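  Example (a minimal usage sketch; `inputs`, `sequence_length`, and the cell
  sizes are placeholders):

  ```python
  cell_fw = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
  cell_bw = tf.compat.v1.nn.rnn_cell.LSTMCell(64)
  # inputs: [batch_size, max_time, depth] with time_major=False (the default)
  (output_fw, output_bw), (state_fw, state_bw) = (
      tf.compat.v1.nn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, inputs,
          sequence_length=sequence_length, dtype=tf.float32))
  # Concatenate the two directions if a single output tensor is preferred:
  outputs = tf.concat([output_fw, output_bw], 2)
  ```
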
366  Args:
367    cell_fw: An instance of RNNCell, to be used for forward direction.
368    cell_bw: An instance of RNNCell, to be used for backward direction.
369    inputs: The RNN inputs.
370      If time_major == False (default), this must be a tensor of shape:
371        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
372      If time_major == True, this must be a tensor of shape: `[max_time,
373        batch_size, ...]`, or a nested tuple of such elements.
374    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
375      containing the actual lengths for each of the sequences in the batch. If
376      not provided, all batch entries are assumed to be full sequences; and time
377      reversal is applied from time `0` to `max_time` for each sequence.
378    initial_state_fw: (optional) An initial state for the forward RNN. This must
379      be a tensor of appropriate type and shape `[batch_size,
380      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
381      tuple of tensors having shapes `[batch_size, s] for s in
382      cell_fw.state_size`.
383    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
384      corresponding properties of `cell_bw`.
385    dtype: (optional) The data type for the initial states and expected output.
386      Required if initial_states are not provided or RNN states have a
387      heterogeneous dtype.
388    parallel_iterations: (Default: 32).  The number of iterations to run in
389      parallel.  Those operations which do not have any temporal dependency and
390      can be run in parallel, will be.  This parameter trades off time for
391      space.  Values >> 1 use more memory but take less time, while smaller
392      values use less memory but computations take longer.
393    swap_memory: Transparently swap the tensors produced in forward inference
394      but needed for back prop from GPU to CPU.  This allows training RNNs which
395      would typically not fit on a single GPU, with very minimal (or no)
396      performance penalty.
397    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
398      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
399      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
400      `time_major = True` is a bit more efficient because it avoids transposes
401      at the beginning and end of the RNN calculation.  However, most TensorFlow
402      data is batch-major, so by default this function accepts input and emits
403      output in batch-major form.
404    scope: VariableScope for the created subgraph; defaults to
405      "bidirectional_rnn"
406
407  Returns:
408    A tuple (outputs, output_states) where:
409      outputs: A tuple (output_fw, output_bw) containing the forward and
410        the backward rnn output `Tensor`.
411        If time_major == False (default),
412          output_fw will be a `Tensor` shaped:
413          `[batch_size, max_time, cell_fw.output_size]`
414          and output_bw will be a `Tensor` shaped:
415          `[batch_size, max_time, cell_bw.output_size]`.
416        If time_major == True,
417          output_fw will be a `Tensor` shaped:
418          `[max_time, batch_size, cell_fw.output_size]`
419          and output_bw will be a `Tensor` shaped:
420          `[max_time, batch_size, cell_bw.output_size]`.
421        It returns a tuple instead of a single concatenated `Tensor`, unlike
422        in the `bidirectional_rnn`. If the concatenated one is preferred,
423        the forward and backward outputs can be concatenated as
424        `tf.concat(outputs, 2)`.
425      output_states: A tuple (output_state_fw, output_state_bw) containing
426        the forward and the backward final states of bidirectional rnn.
427
428  Raises:
429    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
430  """
431  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
432  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
433
434  with vs.variable_scope(scope or "bidirectional_rnn"):
435    # Forward direction
436    with vs.variable_scope("fw") as fw_scope:
437      output_fw, output_state_fw = dynamic_rnn(
438          cell=cell_fw,
439          inputs=inputs,
440          sequence_length=sequence_length,
441          initial_state=initial_state_fw,
442          dtype=dtype,
443          parallel_iterations=parallel_iterations,
444          swap_memory=swap_memory,
445          time_major=time_major,
446          scope=fw_scope)
447
448    # Backward direction
449    if not time_major:
450      time_axis = 1
451      batch_axis = 0
452    else:
453      time_axis = 0
454      batch_axis = 1
455
456    def _reverse(input_, seq_lengths, seq_axis, batch_axis):
457      if seq_lengths is not None:
458        return array_ops.reverse_sequence(
459            input=input_,
460            seq_lengths=seq_lengths,
461            seq_axis=seq_axis,
462            batch_axis=batch_axis)
463      else:
464        return array_ops.reverse(input_, axis=[seq_axis])
465
466    with vs.variable_scope("bw") as bw_scope:
467
468      def _map_reverse(inp):
469        return _reverse(
470            inp,
471            seq_lengths=sequence_length,
472            seq_axis=time_axis,
473            batch_axis=batch_axis)
474
475      inputs_reverse = nest.map_structure(_map_reverse, inputs)
476      tmp, output_state_bw = dynamic_rnn(
477          cell=cell_bw,
478          inputs=inputs_reverse,
479          sequence_length=sequence_length,
480          initial_state=initial_state_bw,
481          dtype=dtype,
482          parallel_iterations=parallel_iterations,
483          swap_memory=swap_memory,
484          time_major=time_major,
485          scope=bw_scope)
486
487  output_bw = _reverse(
488      tmp,
489      seq_lengths=sequence_length,
490      seq_axis=time_axis,
491      batch_axis=batch_axis)
492
493  outputs = (output_fw, output_bw)
494  output_states = (output_state_fw, output_state_bw)
495
496  return (outputs, output_states)
497
498
499@deprecation.deprecated(
500    None,
501    "Please use `keras.layers.RNN(cell)`, which is equivalent to this API")
502@tf_export(v1=["nn.dynamic_rnn"])
503@dispatch.add_dispatch_support
504def dynamic_rnn(cell,
505                inputs,
506                sequence_length=None,
507                initial_state=None,
508                dtype=None,
509                parallel_iterations=None,
510                swap_memory=False,
511                time_major=False,
512                scope=None):
513  """Creates a recurrent neural network specified by RNNCell `cell`.
514
515  Performs fully dynamic unrolling of `inputs`.
516
517  Example:
518
519  ```python
520  # create a BasicRNNCell
521  rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size)
522
523  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
524
525  # defining initial state
526  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
527
528  # 'state' is a tensor of shape [batch_size, cell_state_size]
529  outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data,
530                                     initial_state=initial_state,
531                                     dtype=tf.float32)
532  ```
533
534  ```python
535  # create 2 LSTMCells
536  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
537
538  # create a RNN cell composed sequentially of a number of RNNCells
539  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)
540
541  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
542  # 'state' is a N-tuple where N is the number of LSTMCells containing a
543  # tf.nn.rnn_cell.LSTMStateTuple for each cell
544  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
545                                     inputs=data,
546                                     dtype=tf.float32)
547  ```
548
549
550  Args:
551    cell: An instance of RNNCell.
552    inputs: The RNN inputs.
553      If `time_major == False` (default), this must be a `Tensor` of shape:
554        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
555      If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
556        batch_size, ...]`, or a nested tuple of such elements. This may also be
557        a (possibly nested) tuple of Tensors satisfying this property.  The
558        first two dimensions must match across all the inputs, but otherwise the
559        ranks and other shape components may differ. In this case, input to
560        `cell` at each time-step will replicate the structure of these tuples,
561        except for the time dimension (from which the time is taken). The input
562        to `cell` at each time step will be a `Tensor` or (possibly nested)
563        tuple of Tensors each with dimensions `[batch_size, ...]`.
564    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
565      to copy-through state and zero-out outputs when past a batch element's
566      sequence length.  This parameter enables users to extract the last valid
567      state and properly padded outputs, so it is provided for correctness.
568    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
569      is an integer, this must be a `Tensor` of appropriate type and shape
570      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
571      should be a tuple of tensors having shapes `[batch_size, s] for s in
572      cell.state_size`.
573    dtype: (optional) The data type for the initial state and expected output.
574      Required if initial_state is not provided or RNN state has a heterogeneous
575      dtype.
576    parallel_iterations: (Default: 32).  The number of iterations to run in
577      parallel.  Those operations which do not have any temporal dependency and
578      can be run in parallel, will be.  This parameter trades off time for
579      space.  Values >> 1 use more memory but take less time, while smaller
580      values use less memory but computations take longer.
581    swap_memory: Transparently swap the tensors produced in forward inference
582      but needed for back prop from GPU to CPU.  This allows training RNNs which
583      would typically not fit on a single GPU, with very minimal (or no)
584      performance penalty.
585    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
586      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
587      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
588      `time_major = True` is a bit more efficient because it avoids transposes
589      at the beginning and end of the RNN calculation.  However, most TensorFlow
590      data is batch-major, so by default this function accepts input and emits
591      output in batch-major form.
592    scope: VariableScope for the created subgraph; defaults to "rnn".
593
594  Returns:
595    A pair (outputs, state) where:
596
597    outputs: The RNN output `Tensor`.
598
599      If time_major == False (default), this will be a `Tensor` shaped:
600        `[batch_size, max_time, cell.output_size]`.
601
602      If time_major == True, this will be a `Tensor` shaped:
603        `[max_time, batch_size, cell.output_size]`.
604
605      Note, if `cell.output_size` is a (possibly nested) tuple of integers
606      or `TensorShape` objects, then `outputs` will be a tuple having the
607      same structure as `cell.output_size`, containing Tensors having shapes
608      corresponding to the shape data in `cell.output_size`.
609
610    state: The final state.  If `cell.state_size` is an int, this
611      will be shaped `[batch_size, cell.state_size]`.  If it is a
612      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
613      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
614      be a tuple having the corresponding shapes. If cells are `LSTMCells`,
615      `state` will be a tuple containing an `LSTMStateTuple` for each cell.
616
617  Raises:
618    TypeError: If `cell` is not an instance of RNNCell.
619    ValueError: If inputs is None or an empty list.
620
621  @compatibility(TF2)
622  `tf.compat.v1.nn.dynamic_rnn` is not compatible with eager execution and
623  `tf.function`. Please use `tf.keras.layers.RNN` instead for TF2 migration.
624  Take LSTM as an example, you can instantiate a `tf.keras.layers.RNN` layer
625  with `tf.keras.layers.LSTMCell`, or directly via `tf.keras.layers.LSTM`. Once
626  the keras layer is created, you can get the output and states by calling
627  the layer with input and states. Please refer to [this
628  guide](https://www.tensorflow.org/guide/keras/rnn) for more details about
629  Keras RNN. You can also find more details about the difference and comparison
630  between Keras RNN and TF compat v1 rnn in [this
631  document](https://github.com/tensorflow/community/blob/master/rfcs/20180920-unify-rnn-interface.md)
632
633  #### Structural Mapping to Native TF2
634
635  Before:
636
637  ```python
638  # create 2 LSTMCells
639  rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
640
641  # create a RNN cell composed sequentially of a number of RNNCells
642  multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers)
643
644  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
645  # 'state' is a N-tuple where N is the number of LSTMCells containing a
646  # tf.nn.rnn_cell.LSTMStateTuple for each cell
647  outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell,
648                                               inputs=data,
649                                               dtype=tf.float32)
650  ```
651
652  After:
653
654  ```python
655  # RNN layer can take a list of cells, which will then stack them together.
656  # By default, keras RNN will only return the last timestep output and will not
657  # return states. If you need whole time sequence output as well as the states,
658  # you can set `return_sequences` and `return_state` to True.
659  rnn_layer = tf.keras.layers.RNN([tf.keras.layers.LSTMCell(128),
660                                   tf.keras.layers.LSTMCell(256)],
661                                  return_sequences=True,
662                                  return_state=True)
663  outputs, output_states = rnn_layer(inputs, states)
664  ```
665
666  #### How to Map Arguments
667
668  | TF1 Arg Name          | TF2 Arg Name    | Note                             |
669  | :-------------------- | :-------------- | :------------------------------- |
670  | `cell`                | `cell`          | In the RNN layer constructor     |
671  | `inputs`              | `inputs`        | In the RNN layer `__call__`      |
672  | `sequence_length`     | Not used        | Add a masking layer before the   :
673  :                       :                 : RNN to achieve the same result.  :
674  | `initial_state`       | `initial_state` | In the RNN layer `__call__`      |
675  | `dtype`               | `dtype`         | In the RNN layer constructor     |
676  | `parallel_iterations` | Not supported   |                                  |
677  | `swap_memory`         | Not supported   |                                  |
678  | `time_major`          | `time_major`    | In the RNN layer constructor     |
679  | `scope`               | Not supported   |                                  |
680  @end_compatibility
681  """
682  rnn_cell_impl.assert_like_rnncell("cell", cell)
683
684  with vs.variable_scope(scope or "rnn") as varscope:
685    # Create a new scope in which the caching device is either
686    # determined by the parent scope, or is set to place the cached
687    # Variable using the same placement as for the rest of the RNN.
688    if _should_cache():
689      if varscope.caching_device is None:
690        varscope.set_caching_device(lambda op: op.device)
691
692    # By default, time_major==False and inputs are batch-major: shaped
693    #   [batch, time, depth]
694    # For internal calculations, we transpose to [time, batch, depth]
695    flat_input = nest.flatten(inputs)
696
697    if not time_major:
698      # (B,T,D) => (T,B,D)
699      flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
700      flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
701
702    parallel_iterations = parallel_iterations or 32
703    if sequence_length is not None:
704      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
705      if sequence_length.get_shape().rank not in (None, 1):
706        raise ValueError(
707            f"Argument sequence_length must be a vector of length batch_size."
708            f" Received sequence_length={sequence_length} of shape: "
709            f"{sequence_length.get_shape()}")
710      sequence_length = array_ops.identity(  # Just to find it in the graph.
711          sequence_length,
712          name="sequence_length")
713
714    batch_size = _best_effort_input_batch_size(flat_input)
715
716    if initial_state is not None:
717      state = initial_state
718    else:
719      if not dtype:
720        raise ValueError("If no initial_state is provided, argument `dtype` "
721                         "must be specified")
722      if getattr(cell, "get_initial_state", None) is not None:
723        state = cell.get_initial_state(
724            inputs=None, batch_size=batch_size, dtype=dtype)
725      else:
726        state = cell.zero_state(batch_size, dtype)
727
728    def _assert_has_shape(x, shape):
729      x_shape = array_ops.shape(x)
730      packed_shape = array_ops.stack(shape)
731      return control_flow_ops.Assert(
732          math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
733              "Expected shape for Tensor %s is " % x.name, packed_shape,
734              " but saw shape: ", x_shape
735          ])
736
737    if not context.executing_eagerly() and sequence_length is not None:
738      # Perform some shape validation
739      with ops.control_dependencies(
740          [_assert_has_shape(sequence_length, [batch_size])]):
741        sequence_length = array_ops.identity(
742            sequence_length, name="CheckSeqLen")
743
744    inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
745
746    (outputs, final_state) = _dynamic_rnn_loop(
747        cell,
748        inputs,
749        state,
750        parallel_iterations=parallel_iterations,
751        swap_memory=swap_memory,
752        sequence_length=sequence_length,
753        dtype=dtype)
754
755    # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
756    # If we are performing batch-major calculations, transpose output back
757    # to shape [batch, time, depth]
758    if not time_major:
759      # (T,B,D) => (B,T,D)
760      outputs = nest.map_structure(_transpose_batch_time, outputs)
761
762    return (outputs, final_state)
763
764
765def _dynamic_rnn_loop(cell,
766                      inputs,
767                      initial_state,
768                      parallel_iterations,
769                      swap_memory,
770                      sequence_length=None,
771                      dtype=None):
772  """Internal implementation of Dynamic RNN.
773
774  Args:
775    cell: An instance of RNNCell.
776    inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
777      tuple of such elements.
778    initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
779      `cell.state_size` is a tuple, then this should be a tuple of tensors
780      having shapes `[batch_size, s] for s in cell.state_size`.
781    parallel_iterations: Positive Python int.
782    swap_memory: A Python boolean
783    sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
784    dtype: (optional) Expected dtype of output. If not specified, inferred from
785      initial_state.
786
787  Returns:
788    Tuple `(final_outputs, final_state)`.
789    final_outputs:
790      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
791      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
792      objects, then this returns a (possibly nested) tuple of Tensors matching
793      the corresponding shapes.
794    final_state:
795      A `Tensor`, or possibly nested tuple of Tensors, matching in length
796      and shapes to `initial_state`.
797
798  Raises:
799    ValueError: If the input depth cannot be inferred via shape inference
800      from the inputs.
801    ValueError: If time_step is not the same for all the elements in the
802      inputs.
803    ValueError: If batch_size is not the same for all the elements in the
804      inputs.
805  """
806  state = initial_state
807  assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
808
809  state_size = cell.state_size
810
811  flat_input = nest.flatten(inputs)
812  flat_output_size = nest.flatten(cell.output_size)
813
814  # Construct an initial output
815  input_shape = array_ops.shape(flat_input[0])
816  time_steps = input_shape[0]
817  batch_size = _best_effort_input_batch_size(flat_input)
818
819  inputs_got_shape = tuple(
820      input_.get_shape().with_rank_at_least(3) for input_ in flat_input)
821
822  const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]
823
824  for i, shape in enumerate(inputs_got_shape):
825    if not shape[2:].is_fully_defined():
826      raise ValueError(
827          "Input size (depth of inputs) must be accessible via shape inference,"
828          f" but saw value None for input={flat_input[i]}.")
829    got_time_steps = shape.dims[0].value
830    got_batch_size = shape.dims[1].value
831    if const_time_steps != got_time_steps:
832      raise ValueError(
833          "Time steps is not the same for all the elements in the input in a "
834          f"batch. Received time steps={got_time_steps} for input="
835          f"{flat_input[i]}.")
836    if const_batch_size != got_batch_size:
837      raise ValueError(
838          "Batch_size is not the same for all the elements in the input. "
839          f"Received batch size={got_batch_size} for input={flat_input[i]}.")
840
841  # Prepare dynamic conditional copying of state & output
842  def _create_zero_arrays(size):
843    size = _concat(batch_size, size)
844    return array_ops.zeros(
845        array_ops.stack(size), _infer_state_dtype(dtype, state))
846
847  flat_zero_output = tuple(
848      _create_zero_arrays(output) for output in flat_output_size)
849  zero_output = nest.pack_sequence_as(
850      structure=cell.output_size, flat_sequence=flat_zero_output)
851
852  if sequence_length is not None:
853    min_sequence_length = math_ops.reduce_min(sequence_length)
854    max_sequence_length = math_ops.reduce_max(sequence_length)
855  else:
856    max_sequence_length = time_steps
857
858  time = array_ops.constant(0, dtype=dtypes.int32, name="time")
859
860  with ops.name_scope("dynamic_rnn") as scope:
861    base_name = scope
862
863  def _create_ta(name, element_shape, dtype):
864    return tensor_array_ops.TensorArray(
865        dtype=dtype,
866        size=time_steps,
867        element_shape=element_shape,
868        tensor_array_name=base_name + name)
869
870  in_graph_mode = not context.executing_eagerly()
871  if in_graph_mode:
872    output_ta = tuple(
873        _create_ta(
874            "output_%d" % i,
875            element_shape=(
876                tensor_shape.TensorShape([const_batch_size]).concatenate(
877                    _maybe_tensor_shape_from_tensor(out_size))),
878            dtype=_infer_state_dtype(dtype, state))
879        for i, out_size in enumerate(flat_output_size))
880    input_ta = tuple(
881        _create_ta(
882            "input_%d" % i,
883            element_shape=flat_input_i.shape[1:],
884            dtype=flat_input_i.dtype)
885        for i, flat_input_i in enumerate(flat_input))
886    input_ta = tuple(
887        ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input))
888  else:
889    output_ta = tuple([0 for _ in range(time_steps.numpy())]
890                      for i in range(len(flat_output_size)))
891    input_ta = flat_input
892
893  def _time_step(time, output_ta_t, state):
894    """Take a time step of the dynamic RNN.
895
896    Args:
897      time: int32 scalar Tensor.
898      output_ta_t: List of `TensorArray`s that represent the output.
899      state: nested tuple of vector tensors that represent the state.
900
901    Returns:
902      The tuple (time + 1, output_ta_t with updated flow, new_state).
903    """
904
905    if in_graph_mode:
906      input_t = tuple(ta.read(time) for ta in input_ta)
907      # Restore some shape information
908      for input_, shape in zip(input_t, inputs_got_shape):
909        input_.set_shape(shape[1:])
910    else:
911      input_t = tuple(ta[time.numpy()] for ta in input_ta)
912
913    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
914    # Keras RNN cells only accept state as list, even if it's a single tensor.
915    call_cell = lambda: cell(input_t, state)
916
917    if sequence_length is not None:
918      (output, new_state) = _rnn_step(
919          time=time,
920          sequence_length=sequence_length,
921          min_sequence_length=min_sequence_length,
922          max_sequence_length=max_sequence_length,
923          zero_output=zero_output,
924          state=state,
925          call_cell=call_cell,
926          state_size=state_size,
927          skip_conditionals=True)
928    else:
929      (output, new_state) = call_cell()
930
931    # Pack state if using state tuples
932    output = nest.flatten(output)
933
934    if in_graph_mode:
935      output_ta_t = tuple(
936          ta.write(time, out) for ta, out in zip(output_ta_t, output))
937    else:
938      for ta, out in zip(output_ta_t, output):
939        ta[time.numpy()] = out
940
941    return (time + 1, output_ta_t, new_state)
942
943  if in_graph_mode:
944    # Make sure that we run at least 1 step, if necessary, to ensure
945    # the TensorArrays pick up the dynamic shape.
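    # For example (illustrative numbers): with time_steps = 20 and an
    # all-empty batch where max_sequence_length = 0, the bound below is
    # min(20, max(1, 0)) = 1, so one step still runs and the TensorArrays
    # pick up their dynamic shapes.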
946    loop_bound = math_ops.minimum(time_steps,
947                                  math_ops.maximum(1, max_sequence_length))
948  else:
949    # Using max_sequence_length isn't currently supported in the Eager branch.
950    loop_bound = time_steps
951
952  _, output_final_ta, final_state = control_flow_ops.while_loop(
953      cond=lambda time, *_: time < loop_bound,
954      body=_time_step,
955      loop_vars=(time, output_ta, state),
956      parallel_iterations=parallel_iterations,
957      maximum_iterations=time_steps,
958      swap_memory=swap_memory)
959
960  # Unpack final output if not using output tuples.
961  if in_graph_mode:
962    final_outputs = tuple(ta.stack() for ta in output_final_ta)
963    # Restore some shape information
964    for output, output_size in zip(final_outputs, flat_output_size):
965      shape = _concat([const_time_steps, const_batch_size],
966                      output_size,
967                      static=True)
968      output.set_shape(shape)
969  else:
970    final_outputs = output_final_ta
971
972  final_outputs = nest.pack_sequence_as(
973      structure=cell.output_size, flat_sequence=final_outputs)
974  if not in_graph_mode:
975    final_outputs = nest.map_structure_up_to(
976        cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
977
978  return (final_outputs, final_state)
979
980
981@tf_export(v1=["nn.raw_rnn"])
982@dispatch.add_dispatch_support
983def raw_rnn(cell,
984            loop_fn,
985            parallel_iterations=None,
986            swap_memory=False,
987            scope=None):
988  """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.
989
990  **NOTE: This method is still in testing, and the API may change.**
991
992  This function is a more primitive version of `dynamic_rnn` that provides
993  more direct access to the inputs each iteration.  It also provides more
994  control over when to start and finish reading the sequence, and
995  what to emit for the output.
996
997  For example, it can be used to implement the dynamic decoder of a seq2seq
998  model.
999
1000  Instead of working with `Tensor` objects, most operations work with
1001  `TensorArray` objects directly.
1002
1003  The operation of `raw_rnn`, in pseudo-code, is basically the following:
1004
1005  ```python
1006  time = tf.constant(0, dtype=tf.int32)
1007  (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
1008      time=time, cell_output=None, cell_state=None, loop_state=None)
1009  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
1010  state = initial_state
1011  while not all(finished):
1012    (output, cell_state) = cell(next_input, state)
1013    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
1014        time=time + 1, cell_output=output, cell_state=cell_state,
1015        loop_state=loop_state)
1016    # Emit zeros and copy forward state for minibatch entries that are finished.
1017    state = tf.where(finished, state, next_state)
1018    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
1019    emit_ta = emit_ta.write(time, emit)
1020    # If any new minibatch entries are marked as finished, mark these.
1021    finished = tf.logical_or(finished, next_finished)
1022    time += 1
1023  return (emit_ta, state, loop_state)
1024  ```
1025
1026  with the additional properties that output and state may be (possibly nested)
1027  tuples, as determined by `cell.output_size` and `cell.state_size`, and
1028  as a result the final `state` and `emit_ta` may themselves be tuples.
1029
1030  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:
1031
1032  ```python
1033  inputs = tf.compat.v1.placeholder(shape=(max_time, batch_size, input_depth),
1034                          dtype=tf.float32)
1035  sequence_length = tf.compat.v1.placeholder(shape=(batch_size,),
1036                          dtype=tf.int32)
1037  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
1038  inputs_ta = inputs_ta.unstack(inputs)
1039
1040  cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units)
1041
1042  def loop_fn(time, cell_output, cell_state, loop_state):
1043    emit_output = cell_output  # == None for time == 0
1044    if cell_output is None:  # time == 0
1045      next_cell_state = cell.zero_state(batch_size, tf.float32)
1046    else:
1047      next_cell_state = cell_state
1048    elements_finished = (time >= sequence_length)
1049    finished = tf.reduce_all(elements_finished)
1050    next_input = tf.cond(
1051        finished,
1052        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
1053        lambda: inputs_ta.read(time))
1054    next_loop_state = None
1055    return (elements_finished, next_input, next_cell_state,
1056            emit_output, next_loop_state)
1057
1058  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
1059  outputs = outputs_ta.stack()
1060  ```
1061
1062  Args:
1063    cell: An instance of RNNCell.
1064    loop_fn: A callable that takes inputs `(time, cell_output, cell_state,
1065      loop_state)` and returns the tuple `(finished, next_input,
1066      next_cell_state, emit_output, next_loop_state)`. Here `time` is an int32
1067      scalar `Tensor`, `cell_output` is a `Tensor` or (possibly nested) tuple of
1068      tensors as determined by `cell.output_size`, and `cell_state` is a
1069      `Tensor` or (possibly nested) tuple of tensors, as determined by the
1070      `loop_fn` on its first call (and should match `cell.state_size`).
1071      The outputs are: `finished`, a boolean `Tensor` of
1072      shape `[batch_size]`, `next_input`: the next input to feed to `cell`,
1073      `next_cell_state`: the next state to feed to `cell`,
1074      and `emit_output`: the output to store for this iteration.  Note that
1075        `emit_output` should be a `Tensor` or (possibly nested) tuple of tensors
1076        which is aggregated in the `emit_ta` inside the `while_loop`. For the
1077        first call to `loop_fn`, the `emit_output` corresponds to the
1078        `emit_structure` which is then used to determine the size of the
1079        `zero_tensor` for the `emit_ta` (defaults to `cell.output_size`). For
1080        the subsequent calls to the `loop_fn`, the `emit_output` corresponds to
1081        the actual output tensor that is to be aggregated in the `emit_ta`. The
1082        parameter `cell_state` and output `next_cell_state` may be either a
1083        single or (possibly nested) tuple of tensors.  The parameter
1084        `loop_state` and output `next_loop_state` may be either a single or
1085        (possibly nested) tuple of `Tensor` and `TensorArray` objects.  This
1086        last parameter may be ignored by `loop_fn` and the return value may be
1087        `None`.  If it is not `None`, then the `loop_state` will be propagated
1088        through the RNN loop, for use purely by `loop_fn` to keep track of its
1089        own state. The `next_loop_state` parameter returned may be `None`.  The
1090        first call to `loop_fn` will be `time = 0`, `cell_output = None`,
1091      `cell_state = None`, and `loop_state = None`.  For this call: The
1092        `next_cell_state` value should be the value with which to initialize the
1093        cell's state.  It may be a final state from a previous RNN or it may be
1094        the output of `cell.zero_state()`.  It should be a (possibly nested)
1095        tuple structure of tensors. If `cell.state_size` is an integer, this
1096        must be a `Tensor` of appropriate type and shape `[batch_size,
1097        cell.state_size]`. If `cell.state_size` is a `TensorShape`, this must be
1098        a `Tensor` of appropriate type and shape `[batch_size] +
1099        cell.state_size`. If `cell.state_size` is a (possibly nested) tuple of
1100        ints or `TensorShape`, this will be a tuple having the corresponding
1101        shapes. The `emit_output` value may be either `None` or a (possibly
1102        nested) tuple structure of tensors, e.g., `(tf.zeros(shape_0,
1103        dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. If this first
1104        `emit_output` return value is `None`, then the `emit_ta` result of
1105        `raw_rnn` will have the same structure and dtypes as `cell.output_size`.
1106        Otherwise `emit_ta` will have the same structure, shapes (prepended with
1107        a `batch_size` dimension), and dtypes as `emit_output`.  The actual
1108        values returned for `emit_output` at this initializing call are ignored.
1109        Note, this emit structure must be consistent across all time steps.
1110    parallel_iterations: (Default: 32).  The number of iterations to run in
1111      parallel.  Those operations which do not have any temporal dependency and
1112      can be run in parallel, will be.  This parameter trades off time for
1113      space.  Values >> 1 use more memory but take less time, while smaller
1114      values use less memory but computations take longer.
1115    swap_memory: Transparently swap the tensors produced in forward inference
1116      but needed for back prop from GPU to CPU.  This allows training RNNs which
1117      would typically not fit on a single GPU, with very minimal (or no)
1118      performance penalty.
1119    scope: VariableScope for the created subgraph; defaults to "rnn".
1120
1121  Returns:
1122    A tuple `(emit_ta, final_state, final_loop_state)` where:
1123
1124    `emit_ta`: The RNN output `TensorArray`.
1125       If `loop_fn` returns a (possibly nested) set of Tensors for
1126       `emit_output` during initialization, (inputs `time = 0`,
1127       `cell_output = None`, and `loop_state = None`), then `emit_ta` will
1128       have the same structure, dtypes, and shapes as `emit_output` instead.
1129       If `loop_fn` returns `emit_output = None` during this call,
1130       the structure of `cell.output_size` is used:
1131       If `cell.output_size` is a (possibly nested) tuple of integers
1132       or `TensorShape` objects, then `emit_ta` will be a tuple having the
1133       same structure as `cell.output_size`, containing TensorArrays whose
1134       elements' shapes correspond to the shape data in `cell.output_size`.
1135
1136    `final_state`: The final cell state.  If `cell.state_size` is an int, this
1137      will be shaped `[batch_size, cell.state_size]`.  If it is a
1138      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
1139      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
1140      be a tuple having the corresponding shapes.
1141
1142    `final_loop_state`: The final loop state as returned by `loop_fn`.
1143
1144  Raises:
1145    TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
1146      a `callable`.
1147  """
1148  rnn_cell_impl.assert_like_rnncell("cell", cell)
1149
1150  if not callable(loop_fn):
1151    raise TypeError("Argument `loop_fn` must be a callable. Received: "
1152                    f"{loop_fn}.")
1153
1154  parallel_iterations = parallel_iterations or 32
1155
1156  # Create a new scope in which the caching device is either
1157  # determined by the parent scope, or is set to place the cached
1158  # Variable using the same placement as for the rest of the RNN.
1159  with vs.variable_scope(scope or "rnn") as varscope:
1160    if _should_cache():
1161      if varscope.caching_device is None:
1162        varscope.set_caching_device(lambda op: op.device)
1163
1164    time = constant_op.constant(0, dtype=dtypes.int32)
1165    (elements_finished, next_input,
     initial_state, emit_structure, init_loop_state) = loop_fn(
         time, None, None, None)  # time, cell_output, cell_state, loop_state
    flat_input = nest.flatten(next_input)

    # Need a surrogate loop state for the while_loop if none is available.
    loop_state = (
        init_loop_state if init_loop_state is not None else
        constant_op.constant(0, dtype=dtypes.int32))

    input_shape = [input_.get_shape() for input_ in flat_input]
    static_batch_size = tensor_shape.dimension_at_index(input_shape[0], 0)

    for input_shape_i in input_shape:
      # Static verification that batch sizes all match
      static_batch_size.assert_is_compatible_with(
          tensor_shape.dimension_at_index(input_shape_i, 0))

    batch_size = tensor_shape.dimension_value(static_batch_size)
    const_batch_size = batch_size
    if batch_size is None:
      batch_size = array_ops.shape(flat_input[0])[0]

    nest.assert_same_structure(initial_state, cell.state_size)
    state = initial_state
    flat_state = nest.flatten(state)
    flat_state = [ops.convert_to_tensor(s) for s in flat_state]
    state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)

    if emit_structure is not None:
      flat_emit_structure = nest.flatten(emit_structure)
      flat_emit_size = [
          emit.shape if emit.shape.is_fully_defined() else array_ops.shape(emit)
          for emit in flat_emit_structure
      ]
      flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
    else:
      emit_structure = cell.output_size
      flat_emit_size = nest.flatten(emit_structure)
      flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

    flat_emit_ta = [
        tensor_array_ops.TensorArray(
            dtype=dtype_i,
            dynamic_size=True,
            element_shape=(tensor_shape.TensorShape([
                const_batch_size
            ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
            size=0,
            name="rnn_output_%d" % i)
        for i, (dtype_i,
                size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
    ]
    emit_ta = nest.pack_sequence_as(
        structure=emit_structure, flat_sequence=flat_emit_ta)
    flat_zero_emit = [
        array_ops.zeros(_concat(batch_size, size_i), dtype_i)
        for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
    ]
    zero_emit = nest.pack_sequence_as(
        structure=emit_structure, flat_sequence=flat_zero_emit)

    def condition(unused_time, elements_finished, *_):
      return math_ops.logical_not(math_ops.reduce_all(elements_finished))

    def body(time, elements_finished, current_input, emit_ta, state,
             loop_state):
      """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
      (next_output, cell_state) = cell(current_input, state)

      nest.assert_same_structure(state, cell_state)
      nest.assert_same_structure(cell.output_size, next_output)

      next_time = time + 1
      (next_finished, next_input, next_state, emit_output,
       next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                  loop_state)

      nest.assert_same_structure(state, next_state)
      nest.assert_same_structure(current_input, next_input)
      nest.assert_same_structure(emit_ta, emit_output)

      # If loop_fn returns None for next_loop_state, just reuse the
      # previous one.
      loop_state = loop_state if next_loop_state is None else next_loop_state

      def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""

        def copy_fn(cur_i, cand_i):
          # TensorArray and scalar get passed through.
          if isinstance(cur_i, tensor_array_ops.TensorArray):
            return cand_i
          if cur_i.shape.rank == 0:
            return cand_i
          # Otherwise propagate the old or the new value.
          with ops.colocate_with(cand_i):
            return array_ops.where(elements_finished, cur_i, cand_i)

        return nest.map_structure(copy_fn, current, candidate)

      emit_output = _copy_some_through(zero_emit, emit_output)
      next_state = _copy_some_through(state, next_state)

      emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                                   emit_ta, emit_output)

      elements_finished = math_ops.logical_or(elements_finished, next_finished)

      return (next_time, elements_finished, next_input, emit_ta, next_state,
              loop_state)

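    # Run the loop until every batch entry reports finished; the loop_vars
    # order must match the argument order of `condition` and `body` above.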
    returned = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=[
            time, elements_finished, next_input, emit_ta, state, loop_state
        ],
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    (emit_ta, final_state, final_loop_state) = returned[-3:]

    if init_loop_state is None:
      final_loop_state = None

    return (emit_ta, final_state, final_loop_state)


@deprecation.deprecated(None,
                        "Please use `keras.layers.RNN(cell, unroll=True)`, "
                        "which is equivalent to this API")
@tf_export(v1=["nn.static_rnn"])
@dispatch.add_dispatch_support
def static_rnn(cell,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.

  The simplest form of RNN network generated is:

  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```
  However, a few other options are available:

  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.

  The dynamic calculation performed is, at time `t` for batch row `b`,

  ```python
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))
  ```

  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
      input_size]`, or a nested tuple of such elements.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    sequence_length: Specifies the length of each sequence in inputs. An int32
      or int64 vector (tensor) of size `[batch_size]`, with values in `[0, T)`.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    - outputs is a length T list of outputs (one for each input), or a nested
      tuple of such elements.
    - state is the final state.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the input depth
      (column size) cannot be inferred from inputs via shape inference.
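
  Example: a minimal usage sketch under the TF1 compat API (the cell type,
  input shapes, and number of steps are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  cell = tf.nn.rnn_cell.BasicRNNCell(num_units=8)
  # A length-3 list of [batch_size, input_size] inputs.
  inputs = [tf.placeholder(tf.float32, [None, 4]) for _ in range(3)]
  outputs, state = tf.nn.static_rnn(cell, inputs, dtype=tf.float32)
  # `outputs` is a length-3 list of [batch_size, 8] tensors; `state` is the
  # final cell state.
  ```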
  """
  rnn_cell_impl.assert_like_rnncell("cell", cell)
  if not nest.is_nested(inputs):
    raise TypeError(f"Argument `inputs` must be a sequence. Received: {inputs}")
  if not inputs:
    raise ValueError("Argument `inputs` must not be empty.")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if _should_cache():
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_nested(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().rank != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape.dims[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = tensor_shape.dimension_at_index(
            input_shape, 0), input_shape[1:]
        fixed_batch_size.assert_is_compatible_with(batch_size)
        for i, size in enumerate(input_size.dims):
          if tensor_shape.dimension_value(size) is None:
            raise ValueError(
                f"Input size (dimension {i} of input {flat_input}) must be "
                "accessible via shape inference, but saw value None.")
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    if tensor_shape.dimension_value(fixed_batch_size):
      batch_size = tensor_shape.dimension_value(fixed_batch_size)
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, argument `dtype` "
                         "must be specified")
      if getattr(cell, "get_initial_state", None) is not None:
        state = cell.get_initial_state(
            inputs=None, batch_size=batch_size, dtype=dtype)
      else:
        state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().rank not in (None, 1):
        raise ValueError(
            "Argument `sequence_length` must be a vector of length "
            f"{batch_size}. Received sequence_length={sequence_length}.")

      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _concat(batch_size, output_size)
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        shape = _concat(
            tensor_shape.dimension_value(fixed_batch_size),
            output_size,
            static=True)
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(
          structure=output_size, flat_sequence=flat_zero_output)

      sequence_length = math_ops.cast(sequence_length, dtypes.int32)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

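    # Unroll the RNN: call the cell once per time step, reusing the variable
    # scope after the first step so every step shares the same weights.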
    for time, input_ in enumerate(inputs):
      if time > 0:
        varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()
      outputs.append(output)

    return (outputs, state)


@deprecation.deprecated(None,
                        "Please use `keras.layers.RNN(cell, stateful=True)`, "
                        "which is equivalent to this API")
@tf_export(v1=["nn.static_state_saving_rnn"])
@dispatch.add_dispatch_support
def static_state_saving_rnn(cell,
                            inputs,
                            state_saver,
                            state_name,
                            sequence_length=None,
                            scope=None):
  """RNN that accepts a state saver for time-truncated RNN calculation.

  Args:
    cell: An instance of `RNNCell`.
    inputs: A length T list of inputs, each a `Tensor` of shape `[batch_size,
      input_size]`.
    state_saver: A state saver object with methods `state` and `save_state`.
    state_name: Python string or tuple of strings.  The name to use with the
      state_saver. If the cell returns tuples of states (i.e., `cell.state_size`
      is a tuple) then `state_name` should be a tuple of strings having the same
      length as `cell.state_size`.  Otherwise it should be a single string.
    sequence_length: (optional) An int32/int64 vector of size `[batch_size]`.
      See the documentation for `static_rnn` for more details about
      sequence_length.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:
      outputs is a length T list of outputs (one for each input).
      state is the final state.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the arity and
      type of `state_name` do not match that of `cell.state_size`.
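
  Example: a minimal sketch of the state-saver protocol. `SimpleStateSaver`
  below is illustrative only (not a library class) and assumes a cell with a
  single-tensor state:

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  class SimpleStateSaver(object):
    # Keeps one saved state in a variable and overwrites it on save.

    def __init__(self, batch_size, state_size, dtype):
      self._state = tf.get_variable(
          "saved_state", [batch_size, state_size], dtype=dtype,
          initializer=tf.zeros_initializer())

    def state(self, name):
      del name  # Only a single state is tracked in this sketch.
      return self._state

    def save_state(self, name, value):
      del name
      return self._state.assign(value)

  cell = tf.nn.rnn_cell.BasicRNNCell(num_units=8)
  saver = SimpleStateSaver(batch_size=2, state_size=8, dtype=tf.float32)
  # One truncation window of 3 steps, each of shape [batch_size, input_size].
  inputs = [tf.placeholder(tf.float32, [2, 4]) for _ in range(3)]
  outputs, state = tf.nn.static_state_saving_rnn(
      cell, inputs, state_saver=saver, state_name="h")
  ```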
  """
  state_size = cell.state_size
  state_is_tuple = nest.is_nested(state_size)
  state_name_tuple = nest.is_nested(state_name)

  if state_is_tuple != state_name_tuple:
    raise ValueError("Argument `state_name` should be the same type as "
                     f"`cell.state_size`. Received: state_name={state_name!s}, "
                     f"cell.state_size={state_size!s}.")

  if state_is_tuple:
    state_name_flat = nest.flatten(state_name)
    state_size_flat = nest.flatten(state_size)

    if len(state_name_flat) != len(state_size_flat):
      raise ValueError("Number of elements in argument `state_name` and "
                       "`cell.state_size` are mismatched. Received "
                       f"state_name={state_name} with {len(state_name_flat)} "
                       f"elements and cell.state_size={cell.state_size} with "
                       f"{len(state_size_flat)} elements.")

    initial_state = nest.pack_sequence_as(
        structure=state_size,
        flat_sequence=[state_saver.state(s) for s in state_name_flat])
  else:
    initial_state = state_saver.state(state_name)

  (outputs, state) = static_rnn(
      cell,
      inputs,
      initial_state=initial_state,
      sequence_length=sequence_length,
      scope=scope)

  if state_is_tuple:
    flat_state = nest.flatten(state)
    state_name = nest.flatten(state_name)
    save_state = [
        state_saver.save_state(name, substate)
        for name, substate in zip(state_name, flat_state)
    ]
  else:
    save_state = [state_saver.save_state(state_name, state)]

  with ops.control_dependencies(save_state):
    last_output = outputs[-1]
    flat_last_output = nest.flatten(last_output)
    flat_last_output = [
        array_ops.identity(output) for output in flat_last_output
    ]
    outputs[-1] = nest.pack_sequence_as(
        structure=last_output, flat_sequence=flat_last_output)

    if state_is_tuple:
      state = nest.pack_sequence_as(
          structure=state,
          flat_sequence=[array_ops.identity(s) for s in flat_state])
    else:
      state = array_ops.identity(state)

  return (outputs, state)


@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
                        "keras.layers.RNN(cell, unroll=True))`, which is "
                        "equivalent to this API")
@tf_export(v1=["nn.static_bidirectional_rnn"])
@dispatch.add_dispatch_support
def static_bidirectional_rnn(cell_fw,
                             cell_bw,
                             inputs,
                             initial_state_fw=None,
                             initial_state_bw=None,
                             dtype=None,
                             sequence_length=None,
                             scope=None):
  """Creates a bidirectional recurrent neural network.

  Similar to the unidirectional case above (`static_rnn`) but takes input and
  builds independent forward and backward RNNs, with the final forward and
  backward outputs depth-concatenated, such that the output will have the
  format [time][batch][cell_fw.output_size + cell_bw.output_size]. The
  input_size of the forward and backward cells must match. The initial state
  for both directions is zero by default (but can be set optionally), and no
  intermediate states are ever returned -- the network is fully unrolled for
  the given (passed in) length(s) of the sequence(s), or completely unrolled
  if length(s) is not given.

  Args:
    cell_fw: An instance of RNNCell, to be used for forward direction.
    cell_bw: An instance of RNNCell, to be used for backward direction.
    inputs: A length T list of inputs, each a tensor of shape [batch_size,
      input_size], or a nested tuple of such elements.
    initial_state_fw: (optional) An initial state for the forward RNN. This must
      be a tensor of appropriate type and shape `[batch_size,
      cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a
      tuple of tensors having shapes `[batch_size, s] for s in
      cell_fw.state_size`.
    initial_state_bw: (optional) Same as for `initial_state_fw`, but using the
      corresponding properties of `cell_bw`.
    dtype: (optional) The data type for the initial state.  Required if either
      of the initial states is not provided.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
    scope: VariableScope for the created subgraph; defaults to
      "bidirectional_rnn".

  Returns:
    A tuple (outputs, output_state_fw, output_state_bw) where:
      outputs is a length `T` list of outputs (one for each input), which
        are depth-concatenated forward and backward outputs.
      output_state_fw is the final state of the forward rnn.
      output_state_bw is the final state of the backward rnn.

  Raises:
    TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`.
    ValueError: If inputs is None or an empty list.
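
  Example: a minimal usage sketch under the TF1 compat API (the cell sizes,
  input shapes, and number of steps are illustrative only):

  ```python
  import tensorflow.compat.v1 as tf
  tf.disable_eager_execution()

  cell_fw = tf.nn.rnn_cell.BasicRNNCell(num_units=8)
  cell_bw = tf.nn.rnn_cell.BasicRNNCell(num_units=8)
  # A length-3 list of [batch_size, input_size] inputs.
  inputs = [tf.placeholder(tf.float32, [None, 4]) for _ in range(3)]
  outputs, state_fw, state_bw = tf.nn.static_bidirectional_rnn(
      cell_fw, cell_bw, inputs, dtype=tf.float32)
  # Each element of `outputs` has shape [batch_size, 16]: the forward and
  # backward outputs are concatenated along the depth dimension.
  ```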
  """
  rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw)
  rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw)
  if not nest.is_nested(inputs):
    raise TypeError(f"Argument `inputs` must be a sequence. Received: {inputs}")
  if not inputs:
    raise ValueError("Argument `inputs` must not be empty.")

  with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = static_rnn(
          cell_fw,
          inputs,
          initial_state_fw,
          dtype,
          sequence_length,
          scope=fw_scope)

    # Backward direction
    with vs.variable_scope("bw") as bw_scope:
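      # Run the backward RNN on time-reversed inputs (respecting
      # sequence_length); its outputs are reversed back to the original
      # time order below.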
      reversed_inputs = _reverse_seq(inputs, sequence_length)
      tmp, output_state_bw = static_rnn(
          cell_bw,
          reversed_inputs,
          initial_state_bw,
          dtype,
          sequence_length,
          scope=bw_scope)

  output_bw = _reverse_seq(tmp, sequence_length)
  # Concat each of the forward/backward outputs
  flat_output_fw = nest.flatten(output_fw)
  flat_output_bw = nest.flatten(output_bw)

  flat_outputs = tuple(
      array_ops.concat([fw, bw], 1)
      for fw, bw in zip(flat_output_fw, flat_output_bw))

  outputs = nest.pack_sequence_as(
      structure=output_fw, flat_sequence=flat_outputs)

  return (outputs, output_state_fw, output_state_bw)