# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RNN helpers for TensorFlow models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope as vs


def stack_bidirectional_rnn(cells_fw,
                            cells_bw,
                            inputs,
                            initial_states_fw=None,
                            initial_states_bw=None,
                            dtype=None,
                            sequence_length=None,
                            scope=None):
  """Creates a bidirectional recurrent neural network.

  Stacks several bidirectional RNN layers. The combined forward and backward
  layer outputs are used as the input to the next layer.
  tf.bidirectional_rnn does not allow sharing of forward and backward
  information between layers. The input_size of the first forward and
  backward cells must match. The initial state for both directions is zero
  and no intermediate states are returned.

  As described in https://arxiv.org/abs/1303.5778

  Args:
    cells_fw: List of instances of RNNCell, one per layer,
      to be used for forward direction.
    cells_bw: List of instances of RNNCell, one per layer,
      to be used for backward direction.
    inputs: A length T list of inputs, each a tensor of shape
      [batch_size, input_size], or a nested tuple of such elements.
    initial_states_fw: (optional) A list of the initial states (one per layer)
      for the forward RNN.
      Each tensor must have an appropriate type and shape
      `[batch_size, cell_fw.state_size]`.
    initial_states_bw: (optional) Same as for `initial_states_fw`, but using
      the corresponding properties of `cells_bw`.
    dtype: (optional) The data type for the initial state.  Required if
      either of the initial states is not provided.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
    scope: VariableScope for the created subgraph; defaults to None.

  Returns:
    A tuple (outputs, output_states_fw, output_states_bw) where:
      outputs is a length `T` list of outputs (one for each input), which
        are depth-concatenated forward and backward outputs.
      output_states_fw is the final states, one tensor per layer,
        of the forward RNN.
      output_states_bw is the final states, one tensor per layer,
        of the backward RNN.

  Raises:
    TypeError: If any element of `cells_fw` or `cells_bw` is not an instance
      of `RNNCell`.
    ValueError: If `inputs` is `None`, not a list, or an empty list.
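
  Example:
    A minimal usage sketch; the cell type, sizes, and `inputs` here are
    illustrative assumptions, not requirements:

      cells_fw = [tf.nn.rnn_cell.LSTMCell(32) for _ in range(2)]
      cells_bw = [tf.nn.rnn_cell.LSTMCell(32) for _ in range(2)]
      # inputs: a length-T list of [batch_size, input_size] tensors.
      outputs, states_fw, states_bw = stack_bidirectional_rnn(
          cells_fw, cells_bw, inputs, dtype=tf.float32)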
  """
  if not cells_fw:
    raise ValueError("Must specify at least one fw cell for BidirectionalRNN.")
  if not cells_bw:
    raise ValueError("Must specify at least one bw cell for BidirectionalRNN.")
  if not isinstance(cells_fw, list):
    raise ValueError("cells_fw must be a list of RNNCells (one per layer).")
  if not isinstance(cells_bw, list):
    raise ValueError("cells_bw must be a list of RNNCells (one per layer).")
  if len(cells_fw) != len(cells_bw):
    raise ValueError("Forward and Backward cells must have the same depth.")
  if (initial_states_fw is not None and
      (not isinstance(initial_states_fw, list) or
       len(initial_states_fw) != len(cells_fw))):
    raise ValueError(
        "initial_states_fw must be a list of state tensors (one per layer).")
  if (initial_states_bw is not None and
      (not isinstance(initial_states_bw, list) or
       len(initial_states_bw) != len(cells_bw))):
    raise ValueError(
        "initial_states_bw must be a list of state tensors (one per layer).")
  states_fw = []
  states_bw = []
  prev_layer = inputs

  with vs.variable_scope(scope or "stack_bidirectional_rnn"):
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
      initial_state_fw = None
      initial_state_bw = None
      if initial_states_fw:
        initial_state_fw = initial_states_fw[i]
      if initial_states_bw:
        initial_state_bw = initial_states_bw[i]

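      # static_bidirectional_rnn returns outputs that are already the
      # depth-concatenation of the forward and backward outputs, so they
      # feed the next layer directly as prev_layer.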
      with vs.variable_scope("cell_%d" % i) as cell_scope:
        prev_layer, state_fw, state_bw = rnn.static_bidirectional_rnn(
            cell_fw,
            cell_bw,
            prev_layer,
            initial_state_fw=initial_state_fw,
            initial_state_bw=initial_state_bw,
            sequence_length=sequence_length,
            dtype=dtype,
            scope=cell_scope)
      states_fw.append(state_fw)
      states_bw.append(state_bw)

  return prev_layer, tuple(states_fw), tuple(states_bw)


def stack_bidirectional_dynamic_rnn(cells_fw,
                                    cells_bw,
                                    inputs,
                                    initial_states_fw=None,
                                    initial_states_bw=None,
                                    dtype=None,
                                    sequence_length=None,
                                    parallel_iterations=None,
                                    time_major=False,
                                    scope=None,
                                    swap_memory=False):
  """Creates a dynamic bidirectional recurrent neural network.

  Stacks several bidirectional RNN layers. The combined forward and backward
  layer outputs are used as the input to the next layer.
  tf.bidirectional_rnn does not allow sharing of forward and backward
  information between layers. The input_size of the first forward and
  backward cells must match. The initial state for both directions is zero
  and no intermediate states are returned.

  Args:
    cells_fw: List of instances of RNNCell, one per layer,
      to be used for forward direction.
    cells_bw: List of instances of RNNCell, one per layer,
      to be used for backward direction.
    inputs: The RNN inputs. This must be a tensor of shape
      `[batch_size, max_time, ...]`, or a nested tuple of such elements.
    initial_states_fw: (optional) A list of the initial states (one per layer)
      for the forward RNN.
      Each tensor must have an appropriate type and shape
      `[batch_size, cell_fw.state_size]`.
    initial_states_bw: (optional) Same as for `initial_states_fw`, but using
      the corresponding properties of `cells_bw`.
    dtype: (optional) The data type for the initial state.  Required if
      either of the initial states is not provided.
    sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
      containing the actual lengths for each of the sequences.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Operations that do not have any temporal dependency and can
      be run in parallel will be.  This parameter trades off time for space.
      Values >> 1 use more memory but take less time, while smaller values
      use less memory but computations take longer.
    time_major: The shape format of the inputs and outputs Tensors. If true,
      these Tensors must be shaped [max_time, batch_size, depth]. If false,
      these Tensors must be shaped [batch_size, max_time, depth]. Using
      time_major = True is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation. However, most
      TensorFlow data is batch-major, so by default this function accepts
      input and emits output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to None.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.

  Returns:
    A tuple (outputs, output_states_fw, output_states_bw) where:
      outputs: Output `Tensor` shaped
        `[batch_size, max_time, layers_output]` (or
        `[max_time, batch_size, layers_output]` if `time_major == True`),
        where `layers_output` is the depth-concatenated forward and backward
        outputs of the last layer.
      output_states_fw is the final states, one tensor per layer,
        of the forward RNN.
      output_states_bw is the final states, one tensor per layer,
        of the backward RNN.

  Raises:
    TypeError: If any element of `cells_fw` or `cells_bw` is not an instance
      of `RNNCell`.
    ValueError: If `inputs` is `None`.
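
  Example:
    A minimal usage sketch; the cell type, sizes, and the `inputs` and
    `lengths` tensors here are illustrative assumptions, not requirements:

      cells_fw = [tf.nn.rnn_cell.LSTMCell(32) for _ in range(2)]
      cells_bw = [tf.nn.rnn_cell.LSTMCell(32) for _ in range(2)]
      # inputs: a [batch_size, max_time, input_size] tensor.
      # lengths: an int32 [batch_size] vector of true sequence lengths.
      outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn(
          cells_fw, cells_bw, inputs, dtype=tf.float32,
          sequence_length=lengths)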
  """
  if not cells_fw:
    raise ValueError("Must specify at least one fw cell for BidirectionalRNN.")
  if not cells_bw:
    raise ValueError("Must specify at least one bw cell for BidirectionalRNN.")
  if not isinstance(cells_fw, list):
    raise ValueError("cells_fw must be a list of RNNCells (one per layer).")
  if not isinstance(cells_bw, list):
    raise ValueError("cells_bw must be a list of RNNCells (one per layer).")
  if len(cells_fw) != len(cells_bw):
    raise ValueError("Forward and Backward cells must have the same depth.")
  if (initial_states_fw is not None and
      (not isinstance(initial_states_fw, list) or
       len(initial_states_fw) != len(cells_fw))):
    raise ValueError(
        "initial_states_fw must be a list of state tensors (one per layer).")
  if (initial_states_bw is not None and
      (not isinstance(initial_states_bw, list) or
       len(initial_states_bw) != len(cells_bw))):
    raise ValueError(
        "initial_states_bw must be a list of state tensors (one per layer).")

  states_fw = []
  states_bw = []
  prev_layer = inputs

  with vs.variable_scope(scope or "stack_bidirectional_rnn"):
    for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
      initial_state_fw = None
      initial_state_bw = None
      if initial_states_fw:
        initial_state_fw = initial_states_fw[i]
      if initial_states_bw:
        initial_state_bw = initial_states_bw[i]

      with vs.variable_scope("cell_%d" % i):
        outputs, (state_fw, state_bw) = rnn.bidirectional_dynamic_rnn(
            cell_fw,
            cell_bw,
            prev_layer,
            initial_state_fw=initial_state_fw,
            initial_state_bw=initial_state_bw,
            sequence_length=sequence_length,
            parallel_iterations=parallel_iterations,
            dtype=dtype,
            swap_memory=swap_memory,
            time_major=time_major)
        # Concat the outputs to create the new input.
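        # Axis 2 is the depth axis in both batch-major
        # ([batch_size, max_time, depth]) and time-major
        # ([max_time, batch_size, depth]) layouts, so this concat is
        # correct for either setting of time_major.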
        prev_layer = array_ops.concat(outputs, 2)
      states_fw.append(state_fw)
      states_bw.append(state_bw)

  return prev_layer, tuple(states_fw), tuple(states_bw)