# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Seq2seq layer operations for use in neural networks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import six

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


__all__ = ["Decoder", "dynamic_decode"]


_transpose_batch_time = rnn._transpose_batch_time  # pylint: disable=protected-access
_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access


@six.add_metaclass(abc.ABCMeta)
class Decoder(object):
  """An RNN Decoder abstract interface object.

  Concepts used by this interface:
  - `inputs`: (structure of) tensors and TensorArrays that is passed as input to
    the RNNCell composing the decoder, at each time step.
  - `state`: (structure of) tensors and TensorArrays that is passed to the
    RNNCell instance as the state.
  - `finished`: boolean tensor telling whether each sequence in the batch is
    finished.
  - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each
    time step.
  """

  @property
  def batch_size(self):
    """The batch size of input values."""
    raise NotImplementedError

  @property
  def output_size(self):
    """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s]."""
    raise NotImplementedError

  @property
  def output_dtype(self):
    """A (possibly nested tuple of...) dtype[s]."""
    raise NotImplementedError

  @abc.abstractmethod
  def initialize(self, name=None):
    """Called before any decoding iterations.

    This method must compute initial input values and initial state.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, initial_inputs, initial_state)`: initial values of
      'finished' flags, inputs and state.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def step(self, time, inputs, state, name=None):
    """Called per step of decoding (but only once for dynamic decoding).

    Args:
      time: Scalar `int32` tensor. Current step number.
      inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time
        step.
      state: RNNCell state (possibly nested tuple of) tensor[s] from previous
        time step.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`: `outputs` is an object
      containing the decoder output, `next_state` is a (structure of) state
      tensors and TensorArrays, `next_inputs` is the tensor that should be used
      as input for the next step, `finished` is a boolean tensor telling whether
      the sequence is complete, for each sequence in the batch.
    """
    raise NotImplementedError

  def finalize(self, outputs, final_state, sequence_lengths):
    raise NotImplementedError

  @property
  def tracks_own_finished(self):
    """Describes whether the Decoder keeps track of finished states.

    Most decoders will emit a true/false `finished` value independently
    at each time step.  In this case, the `dynamic_decode` function keeps track
    of which batch entries are already finished, and performs a logical OR to
    insert new batch entries into the finished set.

    Some decoders, however, shuffle batches / beams between time steps and
    `dynamic_decode` will mix up the finished state across these entries because
    it does not track the reshuffle across time steps.  In this case, it is
    up to the decoder to declare that it will keep track of its own finished
    state by setting this property to `True`.

    Returns:
      Python bool.
    """
    return False


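# The following is an illustrative sketch only (not part of this module's API):
# a minimal concrete implementation of the `Decoder` interface above, assuming
# a hypothetical greedy decoder built from an RNNCell `cell`, a Keras `Dense`
# output projection `output_layer`, fixed `start_inputs`, and a user-supplied
# `end_fn` that maps outputs to a per-batch-entry `finished` bool. It shows
# which members `dynamic_decode` relies on, not a definitive implementation.
#
#   class _SketchGreedyDecoder(Decoder):
#
#     def __init__(self, cell, start_inputs, output_layer, end_fn):
#       self._cell = cell
#       self._start_inputs = start_inputs    # [batch, embedding]
#       self._output_layer = output_layer    # projects cell output to logits
#       self._end_fn = end_fn                # outputs -> bool tensor [batch]
#
#     @property
#     def batch_size(self):
#       return array_ops.shape(self._start_inputs)[0]
#
#     @property
#     def output_size(self):
#       return self._output_layer.units     # structure matching `outputs`
#
#     @property
#     def output_dtype(self):
#       return dtypes.float32
#
#     def initialize(self, name=None):
#       finished = array_ops.tile([False], [self.batch_size])
#       initial_state = self._cell.zero_state(self.batch_size, dtypes.float32)
#       return finished, self._start_inputs, initial_state
#
#     def step(self, time, inputs, state, name=None):
#       cell_outputs, next_state = self._cell(inputs, state)
#       outputs = self._output_layer(cell_outputs)
#       next_inputs = outputs                # e.g. feed predictions back in
#       finished = self._end_fn(outputs)
#       return outputs, next_state, next_inputs, finished
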
class BaseDecoder(layers.Layer):
  """An RNN Decoder that is based on a Keras layer.

  Concepts used by this interface:
  - `inputs`: (structure of) tensors and TensorArrays that is passed as input to
    the RNNCell composing the decoder, at each time step.
  - `state`: (structure of) tensors and TensorArrays that is passed to the
    RNNCell instance as the state.
  - `memory`: (structure of) tensors that is usually the full output of the
    encoder, which will be used by the attention wrapper for the RNNCell.
  - `finished`: boolean tensor telling whether each sequence in the batch is
    finished.
  - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each
    time step.
  """

  def __init__(self,
               output_time_major=False,
               impute_finished=False,
               maximum_iterations=None,
               parallel_iterations=32,
               swap_memory=False,
               **kwargs):
    self.output_time_major = output_time_major
    self.impute_finished = impute_finished
    self.maximum_iterations = maximum_iterations
    self.parallel_iterations = parallel_iterations
    self.swap_memory = swap_memory
    super(BaseDecoder, self).__init__(**kwargs)

  def call(self, inputs, initial_state=None, **kwargs):
    init_kwargs = kwargs
    init_kwargs["initial_state"] = initial_state
    return dynamic_decode(self,
                          output_time_major=self.output_time_major,
                          impute_finished=self.impute_finished,
                          maximum_iterations=self.maximum_iterations,
                          parallel_iterations=self.parallel_iterations,
                          swap_memory=self.swap_memory,
                          decoder_init_input=inputs,
                          decoder_init_kwargs=init_kwargs)

  @property
  def batch_size(self):
    """The batch size of input values."""
    raise NotImplementedError

  @property
  def output_size(self):
    """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s]."""
    raise NotImplementedError

  @property
  def output_dtype(self):
    """A (possibly nested tuple of...) dtype[s]."""
    raise NotImplementedError

  def initialize(self, inputs, initial_state=None, **kwargs):
    """Called before any decoding iterations.

    This method must compute initial input values and initial state.

    Args:
      inputs: (structure of) tensors that contains the input for the decoder. In
        the normal case, it is a tensor of shape [batch, timestep, embedding].
      initial_state: (structure of) tensors that contains the initial state for
        the RNNCell.
      **kwargs: Other arguments that are passed in from the layer.call()
        method. It could contain items like input sequence_length, or masking
        for the input.

    Returns:
      `(finished, initial_inputs, initial_state)`: initial values of
      'finished' flags, inputs and state.
    """
    raise NotImplementedError

  def step(self, time, inputs, state):
    """Called per step of decoding (but only once for dynamic decoding).

    Args:
      time: Scalar `int32` tensor. Current step number.
      inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time
        step.
      state: RNNCell state (possibly nested tuple of) tensor[s] from previous
        time step.

    Returns:
      `(outputs, next_state, next_inputs, finished)`: `outputs` is an object
      containing the decoder output, `next_state` is a (structure of) state
      tensors and TensorArrays, `next_inputs` is the tensor that should be used
      as input for the next step, `finished` is a boolean tensor telling whether
      the sequence is complete, for each sequence in the batch.
    """
    raise NotImplementedError

  def finalize(self, outputs, final_state, sequence_lengths):
    raise NotImplementedError

  @property
  def tracks_own_finished(self):
    """Describes whether the Decoder keeps track of finished states.

    Most decoders will emit a true/false `finished` value independently
    at each time step.  In this case, the `dynamic_decode` function keeps track
    of which batch entries are already finished, and performs a logical OR to
    insert new batch entries into the finished set.

    Some decoders, however, shuffle batches / beams between time steps and
    `dynamic_decode` will mix up the finished state across these entries because
    it does not track the reshuffle across time steps.  In this case, it is
    up to the decoder to declare that it will keep track of its own finished
    state by setting this property to `True`.

    Returns:
      Python bool.
    """
    return False

  # TODO(scottzhu): Add build/get_config/from_config and other layer methods.

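# Illustrative usage sketch (the subclass name below is hypothetical): a
# `BaseDecoder` subclass is invoked like any other Keras layer, and its call()
# routes the tensor inputs to dynamic_decode() below via `decoder_init_input`
# and `decoder_init_kwargs`, so extra keyword tensors reach initialize().
#
#   decoder_layer = _SketchBaseDecoder(cell, output_layer, impute_finished=True)
#   final_outputs, final_state, final_lengths = decoder_layer(
#       embedded_targets,
#       initial_state=encoder_state,
#       sequence_length=target_lengths)  # forwarded to initialize() as kwargs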

def _create_zero_outputs(size, dtype, batch_size):
  """Create a zero outputs Tensor structure."""
  def _create(s, d):
    return _zero_state_tensors(s, batch_size, d)

  return nest.map_structure(_create, size, dtype)

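# For example (illustrative only): if decoder.output_size is the pair
# (TensorShape([vocab_size]), TensorShape([])) with dtypes (float32, int32),
# this returns the same nested structure of zero tensors, with shapes
# [batch_size, vocab_size] and [batch_size] respectively.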

def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None,
                   **kwargs):
  """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object.

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
      steps.  Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.
    **kwargs: dict, other keyword arguments for dynamic_decode. It might contain
      arguments used to initialize a `BaseDecoder`, which takes all of its
      tensor inputs during call().

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar.
  """
  if not isinstance(decoder, (Decoder, BaseDecoder)):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  with variable_scope.variable_scope(scope, "decoder") as varscope:
    # Determine context types.
    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
    in_while_loop = (
        control_flow_util.GetContainingWhileContext(ctxt) is not None)
    # Properly cache variable values inside the while_loop.
    # Don't set a caching device when running in a loop, since it is possible
    # that train steps could be wrapped in a tf.while_loop. In that scenario
    # caching prevents forward computations in loop iterations from re-reading
    # the updated weights.
    if not context.executing_eagerly() and not in_while_loop:
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

    if isinstance(decoder, Decoder):
      initial_finished, initial_inputs, initial_state = decoder.initialize()
    else:
      # For BaseDecoder that takes tensor inputs during call.
      decoder_init_input = kwargs.pop("decoder_init_input", None)
      decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {})
      initial_finished, initial_inputs, initial_state = decoder.initialize(
          decoder_init_input, **decoder_init_kwargs)

    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

    if is_xla and maximum_iterations is None:
      raise ValueError("maximum_iterations is required for XLA compilation.")
    if maximum_iterations is not None:
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)
    initial_sequence_lengths = array_ops.zeros_like(
        initial_finished, dtype=dtypes.int32)
    initial_time = constant_op.constant(0, dtype=dtypes.int32)

    def _shape(batch_size, from_shape):
      if (not isinstance(from_shape, tensor_shape.TensorShape) or
          from_shape.ndims == 0):
        return None
      else:
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

    dynamic_size = maximum_iterations is None or not is_xla

    def _create_ta(s, d):
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0 if dynamic_size else maximum_iterations,
          dynamic_size=dynamic_size,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      return math_ops.logical_not(math_ops.reduce_all(finished))

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      if decoder.tracks_own_finished:
        next_finished = decoder_finished
      else:
        next_finished = math_ops.logical_or(decoder_finished, finished)
      next_sequence_lengths = array_ops.where(
          math_ops.logical_not(finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

    try:
      final_outputs, final_state = decoder.finalize(
          final_outputs, final_state, final_sequence_lengths)
    except NotImplementedError:
      pass

    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

  return final_outputs, final_state, final_sequence_lengths

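
# Illustrative usage sketch (`my_decoder` is a hypothetical object implementing
# the `Decoder` interface above): decoding runs until every batch entry is
# marked finished or `maximum_iterations` is reached.
#
#   outputs, final_state, lengths = dynamic_decode(
#       my_decoder,
#       impute_finished=True,     # copy state / zero outputs past `finished`
#       maximum_iterations=50)
#   # `outputs` is batch major here; pass output_time_major=True to keep the
#   # time-major layout produced by the internal TensorArray stacking.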