# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""A tf.nn.dynamic_rnn variant, built on the Recurrent class.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import copy
23
24from tensorflow.contrib.recurrent.python.ops import recurrent
25from tensorflow.python.framework import function
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import variable_scope
30from tensorflow.python.util import nest
31

def _GetDTypesFromStructure(struct):
  dtypes_list = []
  for x in nest.flatten(struct):
    x = ops.convert_to_tensor(x)
    dtypes_list.append(x.dtype)
  return dtypes_list


def _SetShapeFromTemplate(struct, struct_template):
  as_list = nest.flatten(struct)
  template_as_list = nest.flatten(struct_template)
  for element, template in zip(as_list, template_as_list):
    element.set_shape(template.shape)

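# For intuition, the two helpers above operate on arbitrary nested structures.
# A minimal sketch, assuming an `LSTMStateTuple` state (the state and names
# below are assumptions, not part of this module):
#
#   state = rnn_cell_impl.LSTMStateTuple(c=c_tensor, h=h_tensor)
#   _GetDTypesFromStructure(state)                # => [tf.float32, tf.float32]
#   _SetShapeFromTemplate(state, template_state)  # copies static shapes over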

class _FunctionalRnnCell(object):
  """Wrapper around RNNCell which separates state from computation.

  This class accomplishes the following:
  * Turn the cell's `__call__` function into a pure function. The global
    side effects are separated as `theta`. They are the variables created
    for the weights of the computation.
  * Unless the output is aliased as part of the state, extend the state to
    contain the output so that we store the history in `Recurrent`.
  * Set static shapes as required.
  """

  def __init__(self, rnn_cell, seq_inputs, initial_state):
    assert initial_state is not None

    # TODO(drpng): Dtype needs to be configurable.
    input_dtypes = [seq_inputs.dtype] + _GetDTypesFromStructure(initial_state)
    # See _index.
    like_inputs_t = nest.map_structure(
        lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
    input_structure = (like_inputs_t, initial_state)

    @function.Defun(*input_dtypes)
    def FlatCellStep(*flat_inputs):
      """The flattened version of `rnn_cell`."""
      inputs_t, state0 = nest.pack_sequence_as(input_structure, flat_inputs)
      _SetShapeFromTemplate(state0, initial_state)
      _SetShapeFromTemplate(inputs_t, like_inputs_t)
      outputs_t, state1 = rnn_cell(inputs_t, state0)
      state_list = nest.flatten(state1)
      self._output_shape = outputs_t.shape

      if outputs_t in state_list:
        output_index_in_state = state_list.index(outputs_t)
      else:
        output_index_in_state = None

      if output_index_in_state is None:
        self._prepend_output = True
        self._output_state_idx = 0
        return [outputs_t] + state_list
      else:
        self._output_state_idx = output_index_in_state
        self._prepend_output = False
        # To save memory, we don't return the output separately from the
        # state list, since we know it's the same.
        return state_list

    def _ToPureFunction(func):
      # NOTE: This forces the creation of the function.
      if func.captured_inputs:
        pure_func = copy.copy(func)
        # pylint: disable=protected-access
        pure_func._extra_inputs = []
        return pure_func
      return func

    pure_flat_cell_step = _ToPureFunction(FlatCellStep)

    def CellStep(theta, extended_state0, inputs_t):
      """Performs one time step on structured inputs.

      The purpose of this function is to turn the parameters into flattened
      versions, and to resolve the parameter order difference between
      `Recurrent` and `RNNCell`.

      In the event the cell returns a transformed output that is not aliased
      within its state, the `extended_state0` also contains the output as its
      first element.

      Args:
        theta: Weights required for the computation. A structure of tensors.
        extended_state0: the state0, and possibly the output at the previous
          time step. A structure of tensors.
        inputs_t: the inputs at time t.

      Returns:
        A pair of the next state (inclusive of the output), and an empty list
        (unused `extras`).
        The next state is congruent to state0.
      """
      extended_state0_flat = nest.flatten(extended_state0)
      state0_flat = self.MaybeRemoveOutputFromState(extended_state0_flat)
      full_inputs = [inputs_t] + state0_flat + theta
      # Note that the thetas are additional inputs appended as extra
      # parameters.
      cell_out = pure_flat_cell_step(*full_inputs)
      return cell_out, []

    self._cell_step = CellStep
    self._theta = FlatCellStep.captured_inputs
    self._zero_state = rnn_cell.zero_state
    self._state_template = initial_state
    self._output_size = rnn_cell.output_size

  @property
  def extended_initial_state(self):
    if self._prepend_output:
      return [array_ops.zeros(
          self._output_shape,
          dtype=_GetDTypesFromStructure(self._state_template)[0]),
              self._state_template]
    else:
      # The base case, where the output is just the hidden state.
      return self._state_template

  @property
  def cell_step(self):
    return self._cell_step

  @property
  def theta(self):
    return self._theta

  @property
  def state_template(self):
    return self._state_template

  @property
  def output_shape(self):
    return self._output_shape

  def GetOutputFromState(self, state):
    return nest.flatten(state)[self._output_state_idx]

  def MaybeRemoveOutputFromState(self, flat_state):
    if self._prepend_output:
      return flat_state[1:]
    return flat_state

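# Conceptually, `_FunctionalRnnCell` rewires a stateful `rnn_cell` call into
# the pure signature that `recurrent.Recurrent` expects. A rough sketch of the
# correspondence (illustrative only; `my_cell`, `x_t`, `s0` and `extended_s0`
# are assumed names, not part of this module):
#
#   out_t, s1 = my_cell(x_t, s0)                 # RNNCell convention
#   s1_extended, _ = func_cell.cell_step(        # Recurrent convention,
#       func_cell.theta, extended_s0, x_t)       # weights passed explicitly
#
# where `extended_s0` prepends the previous output to the state whenever the
# cell does not already alias its output inside the state.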

def _ApplyLengthsToBatch(sequence_lengths, tf_output):
  # TODO(drpng): just use Update so that we don't carry over the gradients?
  """Sets the output to be zero past the end of each sequence."""
  # output is batch major.
  shape = array_ops.shape(tf_output)
  batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
  output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
  output_time = array_ops.reshape(output_time, [batch_size, max_time])
  lengths = array_ops.tile(
      array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
  is_less = math_ops.cast(
      math_ops.less(output_time, lengths), dtype=tf_output.dtype)
  keep_mask = array_ops.tile(
      array_ops.expand_dims(is_less, -1),
      [1, 1, vector_size])
  final_output = keep_mask * tf_output
  return final_output

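# A small worked example of the masking above (values assumed for
# illustration): with sequence_lengths = [2, 3] and max_time = 3, the
# per-step mask `is_less` is
#
#   [[1., 1., 0.],    # batch element 0: steps 0 and 1 kept, step 2 zeroed
#    [1., 1., 1.]]    # batch element 1: all three steps kept
#
# which is then broadcast over the feature dimension before multiplying
# with `tf_output`.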

def _PickFinalStateFromHistory(acc_state, sequence_length):
  """Implements acc_state[sequence_length - 1]."""
  # This will work on all platforms, unlike the regular slice.
  last_value = []
  for state_var in nest.flatten(acc_state):
    # We compute the following with matrix operations:
    # last_var = state_var[sequence_length - 1]
    shape = array_ops.shape(state_var)
    max_time, batch_size = shape[0], shape[1]
    output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
    output_time = array_ops.reshape(output_time, [batch_size, max_time])
    lengths = array_ops.tile(array_ops.reshape(sequence_length,
                                               [-1, 1]), [1, max_time])
    last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1),
                             dtype=state_var.dtype)
    last_idx = array_ops.transpose(last_idx)
    last_idx_for_bcast = array_ops.expand_dims(last_idx, -1)
    sliced = math_ops.multiply(last_idx_for_bcast, state_var)
    last_var = math_ops.reduce_sum(sliced, 0)
    last_value += [last_var]
  return nest.pack_sequence_as(acc_state, last_value)

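# To see why the reduction above picks out the final step, consider a single
# state variable with max_time = 3 and sequence_length = [2, 3] (values
# assumed for illustration). `last_idx`, after the transpose, is the
# time-major one-hot mask
#
#   [[0., 0.],     # t = 0
#    [1., 0.],     # t = 1  (last valid step of batch element 0)
#    [0., 1.]]     # t = 2  (last valid step of batch element 1)
#
# so multiplying by `state_var` and summing over the time axis keeps exactly
# state_var[sequence_length - 1] for each batch element.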

def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
                       total_time, inputs_lengths, is_reversed):
  """Post-process output of recurrent.

  This function takes the accumulated extended state and extracts the requested
  state and output.

  When `inputs_lengths` has been set, it extracts the output from the
  accumulated state. It also zeroes out the outputs past each sequence's
  length.

  When `is_reversed` is true, the output will be reversed in this function.

  It also sets the static shape information.

  Args:
    extended_acc_state: A structure containing the accumulated state at each
      time. It may contain the output at each time as well.
    extended_final_state: A structure containing the final state. It may
      contain the output at the final time.
    func_cell: The functional wrapper around the cell.
    total_time: A scalar integer tensor.
    inputs_lengths: An integer tensor with one entry per batch element.
    is_reversed: A boolean to indicate if the sequence is reversed.

  Returns:
    A tuple with the outputs at each time, and the final state.
  """
  if inputs_lengths is None or is_reversed:
    flat_final_state = func_cell.MaybeRemoveOutputFromState(
        nest.flatten(extended_final_state))
    tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
  else:
    # The accumulated state is over the entire sequence, so we pick it
    # out from the acc_state sequence.
    flat_acc_state = func_cell.MaybeRemoveOutputFromState(
        nest.flatten(extended_acc_state))
    acc_state = nest.pack_sequence_as(
        func_cell.state_template, flat_acc_state)
    tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)

  output_from_state = func_cell.GetOutputFromState(extended_acc_state)
  if is_reversed:
    output_from_state = array_ops.reverse(output_from_state, [0])
  tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
  tf_output.set_shape(
      [func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
  if inputs_lengths is not None:
    # Need to set the outputs to zero.
    tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
  _SetShapeFromTemplate(tf_state, func_cell.state_template)
  return tf_output, tf_state


# pylint: disable=invalid-name
def functional_rnn(cell,
                   inputs,
                   sequence_length=None,
                   initial_state=None,
                   dtype=None,
                   time_major=False,
                   scope=None,
                   use_tpu=False,
                   reverse=False):
  """Same interface as `tf.nn.dynamic_rnn`."""
  with variable_scope.variable_scope(scope or 'rnn'):
    if not time_major:
      inputs = nest.map_structure(
          lambda t: array_ops.transpose(t, [1, 0, 2]), inputs)
    inputs_flat = nest.flatten(inputs)
    batch_size = array_ops.shape(inputs_flat[0])[1]
    if initial_state is None:
      initial_state = cell.zero_state(batch_size, dtype)
    func_cell = _FunctionalRnnCell(cell, inputs, initial_state)
  if sequence_length is not None:
    max_length = math_ops.reduce_max(sequence_length)
  else:
    max_length = None
  if reverse:
    inputs = array_ops.reverse(inputs, [0])
  extended_acc_state, extended_final_state = recurrent.Recurrent(
      theta=func_cell.theta,
      state0=func_cell.extended_initial_state,
      inputs=inputs,
      cell_fn=func_cell.cell_step,
      max_input_length=max_length,
      use_tpu=use_tpu,
      aligned_end=reverse)

  tf_output, tf_state = _PostProcessOutput(
      extended_acc_state,
      extended_final_state,
      func_cell,
      inputs_flat[0].shape[0],
      sequence_length,
      is_reversed=reverse)

  if time_major:
    tf_output = array_ops.transpose(tf_output, [1, 0, 2])
  return tf_output, tf_state

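# Example usage of `functional_rnn` (an illustrative sketch, not part of the
# original module; the cell, placeholder shapes and dtype are assumptions):
#
#   from tensorflow.python.framework import dtypes
#   from tensorflow.python.ops import array_ops
#   from tensorflow.python.ops import rnn_cell_impl
#
#   cell = rnn_cell_impl.BasicLSTMCell(32)
#   inputs = array_ops.placeholder(dtypes.float32, [8, 10, 16])  # [batch, time, depth]
#   lengths = array_ops.placeholder(dtypes.int32, [8])
#   outputs, final_state = functional_rnn(
#       cell, inputs, sequence_length=lengths, dtype=dtypes.float32)
#   # outputs has shape [8, 10, 32]; final_state matches cell.zero_state.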

def bidirectional_functional_rnn(cell_fw,
                                 cell_bw,
                                 inputs,
                                 initial_state_fw=None,
                                 initial_state_bw=None,
                                 dtype=None,
                                 sequence_length=None,
                                 time_major=False,
                                 use_tpu=False,
                                 fast_reverse=False,
                                 scope=None):
  """Creates a bidirectional recurrent neural network.

  Performs fully dynamic unrolling of inputs in both directions. Built to be API
  compatible with `tf.nn.bidirectional_dynamic_rnn`, but implemented with
  functional control flow for TPU compatibility.

  Args:
    cell_fw: An instance of `tf.contrib.rnn.RNNCell`.
    cell_bw: An instance of `tf.contrib.rnn.RNNCell`.
    inputs: The RNN inputs. If time_major == False (default), this must be a
      Tensor (or hierarchical structure of Tensors) of shape
      [batch_size, max_time, ...]. If time_major == True, this must be a Tensor
      (or hierarchical structure of Tensors) of shape:
      [max_time, batch_size, ...]. The first two dimensions must match across
      all the inputs, but otherwise the ranks and other shape components may
      differ.
    initial_state_fw: An optional initial state for `cell_fw`. Should match
      `cell_fw.zero_state` in structure and type.
    initial_state_bw: An optional initial state for `cell_bw`. Should match
      `cell_bw.zero_state` in structure and type.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_states are not provided or RNN state has a
      heterogeneous dtype.
    sequence_length: An optional int32/int64 vector sized [batch_size]. Used to
      copy-through state and zero-out outputs when past a batch element's
      sequence length. So it's more for correctness than performance.
    time_major: Whether the `inputs` tensor is in "time major" format.
    use_tpu: Whether to enable TPU-compatible operation. If True, does not truly
      reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can
      remove this flag.
    fast_reverse: Whether to use fast tf.reverse to replace tf.reverse_sequence.
      This is only possible when either all sequence lengths are the same inside
      the batch, or when the cell function does not change the state on padded
      input.
    scope: An optional scope name for the dynamic RNN.

  Returns:
    outputs: A tuple of `(output_fw, output_bw)`. The output of the forward and
      backward RNN. If time_major == False (default), these will
      be Tensors shaped: [batch_size, max_time, cell.output_size]. If
      time_major == True, these will be Tensors shaped:
      [max_time, batch_size, cell.output_size]. Note, if cell.output_size is a
      (possibly nested) tuple of integers or TensorShape objects, then the
      output for that direction will be a tuple having the same structure as
      cell.output_size, containing Tensors having shapes corresponding to the
      shape data in cell.output_size.
    final_states: A tuple of `(final_state_fw, final_state_bw)`. A Tensor or
      hierarchical structure of Tensors indicating the final cell state in each
      direction. Must have the same structure and shape as cell.zero_state.

  Raises:
    ValueError: If `initial_state_fw` is None or `initial_state_bw` is None and
      `dtype` is not provided.
  """
  # Keep this code in sync with tf.nn.dynamic_rnn for compatibility.
  with variable_scope.variable_scope(scope or 'bidirectional_rnn'):
    # Forward direction
    with variable_scope.variable_scope('fw') as fw_scope:
      output_fw, output_state_fw = functional_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          time_major=time_major, scope=fw_scope, use_tpu=use_tpu)
    # Backward direction
    if not time_major:
      time_dim = 1
      batch_dim = 0
    else:
      time_dim = 0
      batch_dim = 1

    def _reverse(input_, seq_lengths, seq_dim, batch_dim):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_dim=seq_dim, batch_dim=batch_dim)
      else:
        # See b/69305369.
        assert not use_tpu, (
            'Bidirectional with variable sequence lengths unsupported on TPU')
        return array_ops.reverse(input_, axis=[seq_dim])

    with variable_scope.variable_scope('bw') as bw_scope:
      if not fast_reverse:
        inputs = _reverse(
            inputs,
            seq_lengths=sequence_length,
            seq_dim=time_dim,
            batch_dim=batch_dim)
      output_bw, output_state_bw = functional_rnn(
          cell=cell_bw,
          inputs=inputs,
          sequence_length=sequence_length,
          initial_state=initial_state_bw,
          dtype=dtype,
          time_major=time_major,
          scope=bw_scope,
          use_tpu=use_tpu,
          reverse=fast_reverse)

  if not fast_reverse:
    output_bw = _reverse(
        output_bw,
        seq_lengths=sequence_length,
        seq_dim=time_dim,
        batch_dim=batch_dim)

  outputs = (output_fw, output_bw)
  output_states = (output_state_fw, output_state_bw)

  return (outputs, output_states)
# pylint: enable=invalid-name
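# Example usage of `bidirectional_functional_rnn` (an illustrative sketch,
# not part of the original module; the cells, placeholder shapes and dtype
# below are assumptions):
#
#   from tensorflow.python.framework import dtypes
#   from tensorflow.python.ops import array_ops
#   from tensorflow.python.ops import rnn_cell_impl
#
#   cell_fw = rnn_cell_impl.GRUCell(64)
#   cell_bw = rnn_cell_impl.GRUCell(64)
#   inputs = array_ops.placeholder(dtypes.float32, [4, 20, 128])
#   (out_fw, out_bw), (state_fw, state_bw) = bidirectional_functional_rnn(
#       cell_fw, cell_bw, inputs, dtype=dtypes.float32)
#   # out_fw and out_bw each have shape [4, 20, 64].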