# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for creating sequence-to-sequence models in TensorFlow.

Sequence-to-sequence recurrent neural networks can learn complex functions
that map input sequences to output sequences. These models yield very good
results on a number of tasks, such as speech recognition, parsing, machine
translation, or even constructing automated replies to emails.

Before using this module, it is recommended to read the TensorFlow tutorial
on sequence-to-sequence models. It explains the basic concepts of this module
and shows an end-to-end example of how to build a translation model.
  https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html

Here is an overview of functions available in this module. They all use
a very similar interface, so after reading the above tutorial and using
one of them, others should be easy to substitute.

* Full sequence-to-sequence models.
  - basic_rnn_seq2seq: The most basic RNN-RNN model.
  - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights.
  - embedding_rnn_seq2seq: The basic model with input embedding.
  - embedding_tied_rnn_seq2seq: The tied model with input embedding.
  - embedding_attention_seq2seq: Advanced model with input embedding and
      the neural attention mechanism; recommended for complex tasks.

* Multi-task sequence-to-sequence models.
  - one2many_rnn_seq2seq: The embedding model with multiple decoders.

* Decoders (when you write your own encoder, you can use these to decode;
    e.g., if you want to write a model that generates captions for images).
  - rnn_decoder: The basic decoder based on a pure RNN.
  - attention_decoder: A decoder that uses the attention mechanism.

* Losses.
  - sequence_loss: Loss for a sequence model returning average log-perplexity.
  - sequence_loss_by_example: As above, but not averaging over all examples.

* model_with_buckets: A convenience function to create models with bucketing
    (see the tutorial above for an explanation of why and how to use it).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

# TODO(ebrevdo): Remove once _linear is fully deprecated.
Linear = core_rnn_cell._Linear  # pylint: disable=protected-access,invalid-name


def _extract_argmax_and_embed(embedding,
                              output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and B will be added.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """

  def loop_function(prev, _):
    if output_projection is not None:
      prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = math_ops.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = array_ops.stop_gradient(emb_prev)
    return emb_prev

  return loop_function


def rnn_decoder(decoder_inputs,
                initial_state,
                cell,
                loop_function=None,
                scope=None):
  """RNN decoder for the sequence-to-sequence model.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor with shape [batch_size x cell.state_size].
    cell: rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-st input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    scope: VariableScope for the created subgraph; defaults to "rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing generated outputs.
      state: The state of each cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
        (Note that in some cases, like basic RNN cell or GRU cell, outputs and
         states can be the same. They are different for LSTM cells though.)
  """
  with variable_scope.variable_scope(scope or "rnn_decoder"):
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      output, state = cell(inp, state)
      outputs.append(output)
      if loop_function is not None:
        prev = output
  return outputs, state

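
# The following helper is an illustrative sketch, not part of the original
# module: it shows one plausible way to build a small decoder graph with
# `rnn_decoder`. The cell size, toy shapes, and the name
# `_example_rnn_decoder_usage` are assumptions made for the example only.
def _example_rnn_decoder_usage():
  """Builds a toy `rnn_decoder` graph and returns (outputs, state)."""
  batch_size, input_size, num_steps = 4, 8, 5
  cell = rnn_cell_impl.GRUCell(24)
  decoder_inputs = [
      array_ops.placeholder(dtypes.float32, [batch_size, input_size])
      for _ in xrange(num_steps)
  ]
  initial_state = cell.zero_state(batch_size, dtypes.float32)
  # outputs is a list of num_steps Tensors of shape [batch_size x 24];
  # state is the final GRU state of shape [batch_size x 24].
  return rnn_decoder(decoder_inputs, initial_state, cell)
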

def basic_rnn_seq2seq(encoder_inputs,
                      decoder_inputs,
                      cell,
                      dtype=dtypes.float32,
                      scope=None):
  """Basic RNN sequence-to-sequence model.

  This model first runs an RNN to encode encoder_inputs into a state vector,
  then runs decoder, initialized with the last encoder state, on decoder_inputs.
  Encoder and decoder use the same RNN cell type, but don't share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell in the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
    enc_cell = copy.deepcopy(cell)
    _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
    return rnn_decoder(decoder_inputs, enc_state, cell)

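
# An illustrative sketch, not part of the original module: one way to wire
# `basic_rnn_seq2seq` with toy placeholder inputs. Shapes, the LSTM cell size,
# and the helper name are assumptions for the example only.
def _example_basic_rnn_seq2seq_usage():
  """Builds a toy `basic_rnn_seq2seq` graph and returns (outputs, state)."""
  batch_size, input_size = 4, 8
  cell = rnn_cell_impl.BasicLSTMCell(16)
  encoder_inputs = [
      array_ops.placeholder(dtypes.float32, [batch_size, input_size])
      for _ in xrange(6)
  ]
  decoder_inputs = [
      array_ops.placeholder(dtypes.float32, [batch_size, input_size])
      for _ in xrange(5)
  ]
  # The decoder is initialized with the final encoder state; outputs is a
  # list of 5 Tensors of shape [batch_size x 16].
  return basic_rnn_seq2seq(encoder_inputs, decoder_inputs, cell)
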

def tied_rnn_seq2seq(encoder_inputs,
                     decoder_inputs,
                     cell,
                     loop_function=None,
                     dtype=dtypes.float32,
                     scope=None):
  """RNN sequence-to-sequence model with tied encoder and decoder parameters.

  This model first runs an RNN to encode encoder_inputs into a state vector, and
  then runs decoder, initialized with the last encoder state, on decoder_inputs.
  Encoder and decoder use the same RNN cell and share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to i-th output
      in order to generate i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol), see rnn_decoder for details.
    dtype: The dtype of the initial state of the rnn cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
    scope = scope or "tied_rnn_seq2seq"
    _, enc_state = rnn.static_rnn(
        cell, encoder_inputs, dtype=dtype, scope=scope)
    variable_scope.get_variable_scope().reuse_variables()
    return rnn_decoder(
        decoder_inputs,
        enc_state,
        cell,
        loop_function=loop_function,
        scope=scope)


def embedding_rnn_decoder(decoder_inputs,
                          initial_state,
                          cell,
                          num_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          update_embedding_for_previous=True,
                          scope=None):
  """RNN decoder with embedding and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each fed
      previous output will first be multiplied by W and B will be added.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)).
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope:
    if output_projection is not None:
      dtype = scope.dtype
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = (embedding_ops.embedding_lookup(embedding, i)
               for i in decoder_inputs)
    return rnn_decoder(
        emb_inp, initial_state, cell, loop_function=loop_function)


def embedding_rnn_seq2seq(encoder_inputs,
                          decoder_inputs,
                          cell,
                          num_encoder_symbols,
                          num_decoder_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          dtype=None,
                          scope=None):
  """Embedding RNN sequence-to-sequence model.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and B will be added.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope:
    if dtype is not None:
      scope.set_dtype(dtype)
    else:
      dtype = scope.dtype

    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)

    if isinstance(feed_previous, bool):
      return embedding_rnn_decoder(
          decoder_inputs,
          encoder_state,
          cell,
          num_decoder_symbols,
          embedding_size,
          output_projection=output_projection,
          feed_previous=feed_previous)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_rnn_decoder(
            decoder_inputs,
            encoder_state,
            cell,
            num_decoder_symbols,
            embedding_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state

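
# An illustrative sketch, not part of the original module: a greedy-decoding
# `embedding_rnn_seq2seq` graph over integer token ids. Vocabulary sizes,
# sequence lengths, the cell, and the helper name are assumptions for the
# example only.
def _example_embedding_rnn_seq2seq_usage():
  """Builds a toy `embedding_rnn_seq2seq` graph and returns (outputs, state)."""
  cell = rnn_cell_impl.GRUCell(32)
  encoder_inputs = [
      array_ops.placeholder(dtypes.int32, [None]) for _ in xrange(6)
  ]
  decoder_inputs = [
      array_ops.placeholder(dtypes.int32, [None]) for _ in xrange(7)
  ]
  # With output_projection=None the decoder cell is wrapped in an
  # OutputProjectionWrapper, so each output has shape
  # [batch_size x num_decoder_symbols].
  return embedding_rnn_seq2seq(
      encoder_inputs,
      decoder_inputs,
      cell,
      num_encoder_symbols=100,
      num_decoder_symbols=120,
      embedding_size=32,
      feed_previous=True)
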

def embedding_tied_rnn_seq2seq(encoder_inputs,
                               decoder_inputs,
                               cell,
                               num_symbols,
                               embedding_size,
                               num_decoder_symbols=None,
                               output_projection=None,
                               feed_previous=False,
                               dtype=None,
                               scope=None):
  """Embedding RNN sequence-to-sequence model with tied (shared) parameters.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_symbols x input_size]). Then it runs an RNN to encode embedded
  encoder_inputs into a state vector. Next, it embeds decoder_inputs using
  the same embedding. Then it runs RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs. The decoder output is over symbols
  from 0 to num_decoder_symbols - 1 if num_decoder_symbols is not None;
  otherwise it is over 0 to num_symbols - 1.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_symbols: Integer; number of symbols for both encoder and decoder.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_decoder_symbols: Integer; number of output symbols for decoder. If
      provided, the decoder output is over symbols 0 to num_decoder_symbols - 1.
      Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that
      this assumes that the vocabulary is set up such that the first
      num_decoder_symbols of num_symbols are part of decoding.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and B will be added.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype to use for the initial RNN states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_symbols] containing the generated
        outputs where output_symbols = num_decoder_symbols if
        num_decoder_symbols is not None otherwise output_symbols = num_symbols.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(
      scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    if output_projection is not None:
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable(
        "embedding", [num_symbols, embedding_size], dtype=dtype)

    emb_encoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in encoder_inputs
    ]
    emb_decoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in decoder_inputs
    ]

    output_symbols = num_symbols
    if num_decoder_symbols is not None:
      output_symbols = num_decoder_symbols
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols)

    if isinstance(feed_previous, bool):
      loop_function = _extract_argmax_and_embed(embedding, output_projection,
                                                True) if feed_previous else None
      return tied_rnn_seq2seq(
          emb_encoder_inputs,
          emb_decoder_inputs,
          cell,
          loop_function=loop_function,
          dtype=dtype)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      loop_function = _extract_argmax_and_embed(
          embedding, output_projection, False) if feed_previous_bool else None
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = tied_rnn_seq2seq(
            emb_encoder_inputs,
            emb_decoder_inputs,
            cell,
            loop_function=loop_function,
            dtype=dtype)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    # Calculate zero-state to know its structure.
    static_batch_size = encoder_inputs[0].get_shape()[0]
    for inp in encoder_inputs[1:]:
      static_batch_size.merge_with(inp.get_shape()[0])
    batch_size = static_batch_size.value
    if batch_size is None:
      batch_size = array_ops.shape(encoder_inputs[0])[0]
    zero_state = cell.zero_state(batch_size, dtype)
    if nest.is_sequence(zero_state):
      state = nest.pack_sequence_as(
          structure=zero_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


def attention_decoder(decoder_inputs,
                      initial_state,
                      attention_states,
                      cell,
                      output_size=None,
                      num_heads=1,
                      loop_function=None,
                      dtype=None,
                      scope=None,
                      initial_state_attention=False):
  """RNN decoder with attention for the sequence-to-sequence model.

  In this context "attention" means that, during decoding, the RNN can look up
  information in the additional tensor attention_states, and it does this by
  focusing on a few entries from the tensor. This model has proven to yield
  especially good results in a number of sequence-to-sequence tasks. This
  implementation is based on http://arxiv.org/abs/1412.7449 (see below for
  details). It is recommended for complex sequence-to-sequence tasks.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    output_size: Size of the output vectors; if None, we use cell.output_size.
    num_heads: Number of attention heads that read from attention_states.
    loop_function: If not None, this function will be applied to i-th output
      in order to generate i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    dtype: The dtype to use for the RNN initial state (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors of
        shape [batch_size x output_size]. These represent the generated outputs.
        Output i is computed from input i (which is either the i-th element
        of decoder_inputs or loop_function(output {i-1}, i)) as follows.
        First, we run the cell on a combination of the input and previous
        attention masks:
          cell_output, new_state = cell(linear(input, prev_attn), prev_state).
        Then, we calculate new attention masks:
          new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
        and then we calculate the output:
          output = linear(cell_output, new_attn).
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: when num_heads is not positive, there are no inputs, shapes
      of attention_states are not set, or input size cannot be inferred
      from the input.
  """
  if not decoder_inputs:
    raise ValueError("Must provide at least 1 input to attention decoder.")
  if num_heads < 1:
    raise ValueError("With fewer than 1 head, use a non-attention decoder.")
  if attention_states.get_shape()[2].value is None:
    raise ValueError("Shape[2] of attention_states must be known: %s" %
                     attention_states.get_shape())
  if output_size is None:
    output_size = cell.output_size

  with variable_scope.variable_scope(
      scope or "attention_decoder", dtype=dtype) as scope:
    dtype = scope.dtype

    batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
    attn_length = attention_states.get_shape()[1].value
    if attn_length is None:
      attn_length = array_ops.shape(attention_states)[1]
    attn_size = attention_states.get_shape()[2].value

    # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
    hidden = array_ops.reshape(attention_states,
                               [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable("AttnW_%d" % a,
                                      [1, 1, attn_size, attention_vec_size])
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))

    state = initial_state

    def attention(query):
      """Put attention masks on hidden using hidden_features and query."""
      ds = []  # Results of attention reads will be stored here.
      if nest.is_sequence(query):  # If the query is a tuple, flatten it.
        query_list = nest.flatten(query)
        for q in query_list:  # Check that ndims == 2 if specified.
          ndims = q.get_shape().ndims
          if ndims:
            assert ndims == 2
        query = array_ops.concat(query_list, 1)
      for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          y = Linear(query, attention_vec_size, True)(query)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                                  [2, 3])
          a = nn_ops.softmax(s)
          # Now calculate the attention-weighted vector d.
          d = math_ops.reduce_sum(
              array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

    outputs = []
    prev = None
    batch_attn_size = array_ops.stack([batch_size, attn_size])
    attns = [
        array_ops.zeros(
            batch_attn_size, dtype=dtype) for _ in xrange(num_heads)
    ]
    for a in attns:  # Ensure the second shape of attention vectors is set.
      a.set_shape([None, attn_size])
    if initial_state_attention:
      attns = attention(initial_state)
    for i, inp in enumerate(decoder_inputs):
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      # If loop_function is set, we use it instead of decoder_inputs.
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      # Merge input and previous attentions into one vector of the right size.
      input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)

      inputs = [inp] + attns
      x = Linear(inputs, input_size, True)(inputs)
      # Run the RNN.
      cell_output, state = cell(x, state)
      # Run the attention mechanism.
      if i == 0 and initial_state_attention:
        with variable_scope.variable_scope(
            variable_scope.get_variable_scope(), reuse=True):
          attns = attention(state)
      else:
        attns = attention(state)

      with variable_scope.variable_scope("AttnOutputProjection"):
        inputs = [cell_output] + attns
        output = Linear(inputs, output_size, True)(inputs)
      if loop_function is not None:
        prev = output
      outputs.append(output)

  return outputs, state

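
# An illustrative sketch, not part of the original module: a minimal
# `attention_decoder` graph with a single attention head. The toy shapes,
# cell size, and helper name are assumptions for the example only.
def _example_attention_decoder_usage():
  """Builds a toy `attention_decoder` graph and returns (outputs, state)."""
  batch_size, input_size, attn_length, attn_size = 4, 8, 6, 16
  cell = rnn_cell_impl.GRUCell(16)
  attention_states = array_ops.placeholder(
      dtypes.float32, [batch_size, attn_length, attn_size])
  decoder_inputs = [
      array_ops.placeholder(dtypes.float32, [batch_size, input_size])
      for _ in xrange(5)
  ]
  initial_state = cell.zero_state(batch_size, dtypes.float32)
  # Each output has shape [batch_size x cell.output_size] because
  # output_size defaults to cell.output_size.
  return attention_decoder(
      decoder_inputs, initial_state, attention_states, cell)
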

def embedding_attention_decoder(decoder_inputs,
                                initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                embedding_size,
                                num_heads=1,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """RNN decoder with embedding and attention and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_size: Size of the output vectors; if None, use cell.output_size.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has shape
      [num_symbols]; if provided and feed_previous=True, each fed previous
      output will first be multiplied by W and B will be added.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)).
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    dtype: The dtype to use for the RNN initial states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  if output_size is None:
    output_size = cell.output_size
  if output_projection is not None:
    proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
    proj_biases.get_shape().assert_is_compatible_with([num_symbols])

  with variable_scope.variable_scope(
      scope or "embedding_attention_decoder", dtype=dtype) as scope:

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = [
        embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs
    ]
    return attention_decoder(
        emb_inp,
        initial_state,
        attention_states,
        cell,
        output_size=output_size,
        num_heads=num_heads,
        loop_function=loop_function,
        initial_state_attention=initial_state_attention)


def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """Embedding sequence-to-sequence model with attention.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. It keeps the outputs of this
  RNN at every step to use for attention later. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs attention decoder, initialized with the last
  encoder state, on embedded decoder_inputs and attending to encoder outputs.

  Warning: when output_projection is None, the size of the attention vectors
  and variables will be made proportional to num_decoder_symbols, which can be
  large.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and B will be added.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial RNN state (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_seq2seq".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x num_decoder_symbols] containing the generated
        outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(
      scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype
    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    encoder_outputs, encoder_state = rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)

    # First calculate a concatenation of encoder outputs to put attention on.
    top_states = [
        array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs
    ]
    attention_states = array_ops.concat(top_states, 1)

    # Decoder.
    output_size = None
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
      output_size = num_decoder_symbols

    if isinstance(feed_previous, bool):
      return embedding_attention_decoder(
          decoder_inputs,
          encoder_state,
          attention_states,
          cell,
          num_decoder_symbols,
          embedding_size,
          num_heads=num_heads,
          output_size=output_size,
          output_projection=output_projection,
          feed_previous=feed_previous,
          initial_state_attention=initial_state_attention)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_attention_decoder(
            decoder_inputs,
            encoder_state,
            attention_states,
            cell,
            num_decoder_symbols,
            embedding_size,
            num_heads=num_heads,
            output_size=output_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False,
            initial_state_attention=initial_state_attention)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state

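
# An illustrative sketch, not part of the original module: a greedy-decoding
# `embedding_attention_seq2seq` graph over integer token ids. Vocabulary
# sizes, sequence lengths, the cell, and the helper name are assumptions for
# the example only.
def _example_embedding_attention_seq2seq_usage():
  """Builds a toy attention seq2seq graph and returns (outputs, state)."""
  cell = rnn_cell_impl.GRUCell(64)
  encoder_inputs = [
      array_ops.placeholder(dtypes.int32, [None]) for _ in xrange(8)
  ]
  decoder_inputs = [
      array_ops.placeholder(dtypes.int32, [None]) for _ in xrange(9)
  ]
  # With output_projection=None, each output has shape
  # [batch_size x num_decoder_symbols].
  return embedding_attention_seq2seq(
      encoder_inputs,
      decoder_inputs,
      cell,
      num_encoder_symbols=100,
      num_decoder_symbols=120,
      embedding_size=64,
      feed_previous=True)
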

def one2many_rnn_seq2seq(encoder_inputs,
                         decoder_inputs_dict,
                         enc_cell,
                         dec_cells_dict,
                         num_encoder_symbols,
                         num_decoder_symbols_dict,
                         embedding_size,
                         feed_previous=False,
                         dtype=None,
                         scope=None):
  """One-to-many RNN sequence-to-sequence model (multi-task).

  This is a multi-task sequence-to-sequence model with one encoder and multiple
  decoders. Reference to multi-task sequence-to-sequence learning can be found
  here: http://arxiv.org/abs/1511.06114

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs_dict: A dictionary mapping decoder name (string) to
      the corresponding decoder_inputs; each decoder_inputs is a list of 1D
      Tensors of shape [batch_size]; num_decoders is defined as
      len(decoder_inputs_dict).
    enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and
      size.
    dec_cells_dict: A dictionary mapping decoder name (string) to an
      instance of tf.nn.rnn_cell.RNNCell.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
      integer specifying number of symbols for the corresponding decoder;
      len(num_decoder_symbols_dict) must be equal to num_decoders.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of
      decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "one2many_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs_dict, state_dict), where:
      outputs_dict: A mapping from decoder name (string) to a list of the same
        length as decoder_inputs_dict[name]; each element in the list is a 2D
        Tensors with shape [batch_size x num_decoder_symbol_list[name]]
        containing the generated outputs.
      state_dict: A mapping from decoder name (string) to the final state of the
        corresponding decoder RNN; it is a 2D Tensor of shape
        [batch_size x cell.state_size].

  Raises:
    TypeError: if enc_cell or any of the dec_cells are not instances of RNNCell.
    ValueError: if len(dec_cells) != len(decoder_inputs_dict).
  """
  outputs_dict = {}
  state_dict = {}

  if not isinstance(enc_cell, rnn_cell_impl.RNNCell):
    raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
  if set(dec_cells_dict) != set(decoder_inputs_dict):
    raise ValueError("keys of dec_cells_dict != keys of decoder_inputs_dict")
  for dec_cell in dec_cells_dict.values():
    if not isinstance(dec_cell, rnn_cell_impl.RNNCell):
      raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))

  with variable_scope.variable_scope(
      scope or "one2many_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    # Encoder.
    enc_cell = core_rnn_cell.EmbeddingWrapper(
        enc_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    for name, decoder_inputs in decoder_inputs_dict.items():
      num_decoder_symbols = num_decoder_symbols_dict[name]
      dec_cell = dec_cells_dict[name]

      with variable_scope.variable_scope("one2many_decoder_" + str(
          name)) as scope:
        dec_cell = core_rnn_cell.OutputProjectionWrapper(
            dec_cell, num_decoder_symbols)
        if isinstance(feed_previous, bool):
          outputs, state = embedding_rnn_decoder(
              decoder_inputs,
              encoder_state,
              dec_cell,
              num_decoder_symbols,
              embedding_size,
              feed_previous=feed_previous)
        else:
          # If feed_previous is a Tensor, we construct 2 graphs and use cond.
          def filled_embedding_rnn_decoder(feed_previous):
            """The current decoder with a fixed feed_previous parameter."""
            # pylint: disable=cell-var-from-loop
            reuse = None if feed_previous else True
            vs = variable_scope.get_variable_scope()
            with variable_scope.variable_scope(vs, reuse=reuse):
              outputs, state = embedding_rnn_decoder(
                  decoder_inputs,
                  encoder_state,
                  dec_cell,
                  num_decoder_symbols,
                  embedding_size,
                  feed_previous=feed_previous)
            # pylint: enable=cell-var-from-loop
            state_list = [state]
            if nest.is_sequence(state):
              state_list = nest.flatten(state)
            return outputs + state_list

          outputs_and_state = control_flow_ops.cond(
              feed_previous, lambda: filled_embedding_rnn_decoder(True),
              lambda: filled_embedding_rnn_decoder(False))
          # Outputs length is the same as for decoder inputs.
          outputs_len = len(decoder_inputs)
          outputs = outputs_and_state[:outputs_len]
          state_list = outputs_and_state[outputs_len:]
          state = state_list[0]
          if nest.is_sequence(encoder_state):
            state = nest.pack_sequence_as(
                structure=encoder_state, flat_sequence=state_list)
      outputs_dict[name] = outputs
      state_dict[name] = state

  return outputs_dict, state_dict

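
# An illustrative sketch, not part of the original module: a two-decoder
# `one2many_rnn_seq2seq` graph. Decoder names, vocabulary sizes, sequence
# lengths, and the helper name are assumptions for the example only.
def _example_one2many_rnn_seq2seq_usage():
  """Builds a toy one-to-many graph and returns (outputs_dict, state_dict)."""
  enc_cell = rnn_cell_impl.GRUCell(32)
  dec_cells_dict = {
      "fr": rnn_cell_impl.GRUCell(32),
      "de": rnn_cell_impl.GRUCell(32),
  }
  encoder_inputs = [
      array_ops.placeholder(dtypes.int32, [None]) for _ in xrange(5)
  ]
  decoder_inputs_dict = {
      name: [array_ops.placeholder(dtypes.int32, [None]) for _ in xrange(6)]
      for name in dec_cells_dict
  }
  num_decoder_symbols_dict = {"fr": 90, "de": 110}
  # One shared encoder feeds both decoders; each decoder gets its own
  # output projection to its own vocabulary size.
  return one2many_rnn_seq2seq(
      encoder_inputs,
      decoder_inputs_dict,
      enc_cell,
      dec_cells_dict,
      num_encoder_symbols=100,
      num_decoder_symbols_dict=num_decoder_symbols_dict,
      embedding_size=32)
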

def sequence_loss_by_example(logits,
                             targets,
                             weights,
                             average_across_timesteps=True,
                             softmax_loss_function=None,
                             name=None):
  """Weighted cross-entropy loss for a sequence of logits (per example).

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    name: Optional name for this operation, default: "sequence_loss_by_example".

  Returns:
    1D batch-sized float Tensor: The log-perplexity for each sequence.

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  if len(targets) != len(logits) or len(weights) != len(logits):
    raise ValueError("Lengths of logits, weights, and targets must be the same "
                     "%d, %d, %d." % (len(logits), len(weights), len(targets)))
  with ops.name_scope(name, "sequence_loss_by_example",
                      logits + targets + weights):
    log_perp_list = []
    for logit, target, weight in zip(logits, targets, weights):
      if softmax_loss_function is None:
        # TODO(irving,ebrevdo): This reshape is needed because
        # sequence_loss_by_example is called with scalars sometimes, which
        # violates our general scalar strictness policy.
        target = array_ops.reshape(target, [-1])
        crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
            labels=target, logits=logit)
      else:
        crossent = softmax_loss_function(labels=target, logits=logit)
      log_perp_list.append(crossent * weight)
    log_perps = math_ops.add_n(log_perp_list)
    if average_across_timesteps:
      total_size = math_ops.add_n(weights)
      total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
      log_perps /= total_size
  return log_perps


def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):
  """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    average_across_batch: If set, divide the returned cost by the batch size.
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A scalar float Tensor: The average log-perplexity per symbol (weighted).

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  with ops.name_scope(name, "sequence_loss", logits + targets + weights):
    cost = math_ops.reduce_sum(
        sequence_loss_by_example(
            logits,
            targets,
            weights,
            average_across_timesteps=average_across_timesteps,
            softmax_loss_function=softmax_loss_function))
    if average_across_batch:
      batch_size = array_ops.shape(targets[0])[0]
      return cost / math_ops.cast(batch_size, cost.dtype)
    else:
      return cost

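
# An illustrative sketch, not part of the original module: computing the
# bucketless sequence loss for a 3-step toy problem. Shapes and the helper
# name are assumptions for the example only.
def _example_sequence_loss_usage():
  """Builds a toy `sequence_loss` op and returns the scalar loss Tensor."""
  num_steps, num_symbols = 3, 10
  logits = [
      array_ops.placeholder(dtypes.float32, [None, num_symbols])
      for _ in xrange(num_steps)
  ]
  targets = [
      array_ops.placeholder(dtypes.int32, [None]) for _ in xrange(num_steps)
  ]
  # Weight every time-step equally; padding steps would normally get weight 0.
  weights = [array_ops.ones_like(t, dtype=dtypes.float32) for t in targets]
  return sequence_loss(logits, targets, weights)
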

def model_with_buckets(encoder_inputs,
                       decoder_inputs,
                       targets,
                       weights,
                       buckets,
                       seq2seq,
                       softmax_loss_function=None,
                       per_example_loss=False,
                       name=None):
  """Create a sequence-to-sequence model with support for bucketing.

  The seq2seq argument is a function that defines a sequence-to-sequence model,
  e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
      x, y, rnn_cell.GRUCell(24))

  Args:
    encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
    decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
    targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
    weights: List of 1D batch-sized float-Tensors to weight the targets.
    buckets: A list of pairs of (input size, output size) for each bucket.
    seq2seq: A sequence-to-sequence model function; it takes 2 inputs that
      agree with encoder_inputs and decoder_inputs, and returns a pair
      consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    per_example_loss: Boolean. If set, the returned loss will be a batch-sized
      tensor of losses for each sequence in the batch. If unset, it will be
      a scalar with the averaged loss from all examples.
    name: Optional name for this operation, defaults to "model_with_buckets".

  Returns:
    A tuple of the form (outputs, losses), where:
      outputs: The outputs for each bucket. Its j'th element consists of a list
        of 2D Tensors. The shape of output tensors can be either
        [batch_size x output_size] or [batch_size x num_decoder_symbols]
        depending on the seq2seq model used.
      losses: List of scalar Tensors, representing losses for each bucket, or,
        if per_example_loss is set, a list of 1D batch-sized float Tensors.

  Raises:
    ValueError: If length of encoder_inputs, targets, or weights is smaller
      than the largest (last) bucket.
  """
  if len(encoder_inputs) < buckets[-1][0]:
    raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
                     "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
  if len(targets) < buckets[-1][1]:
    raise ValueError("Length of targets (%d) must be at least that of last "
                     "bucket (%d)." % (len(targets), buckets[-1][1]))
  if len(weights) < buckets[-1][1]:
    raise ValueError("Length of weights (%d) must be at least that of last "
                     "bucket (%d)." % (len(weights), buckets[-1][1]))

  all_inputs = encoder_inputs + decoder_inputs + targets + weights
  losses = []
  outputs = []
  with ops.name_scope(name, "model_with_buckets", all_inputs):
    for j, bucket in enumerate(buckets):
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
        bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
                                    decoder_inputs[:bucket[1]])
        outputs.append(bucket_outputs)
        if per_example_loss:
          losses.append(
              sequence_loss_by_example(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))
        else:
          losses.append(
              sequence_loss(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))

  return outputs, losses

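
# An illustrative sketch, not part of the original module: wiring
# `model_with_buckets` around `embedding_rnn_seq2seq` for two buckets. The
# bucket sizes, vocabulary sizes, and the helper name are assumptions for the
# example only; constructing the cell inside the seq2seq function avoids
# deep-copying a cell that has already created variables.
def _example_model_with_buckets_usage():
  """Builds a toy bucketed model and returns (outputs, losses)."""
  buckets = [(4, 4), (8, 8)]
  encoder_inputs = [
      array_ops.placeholder(dtypes.int32, [None])
      for _ in xrange(buckets[-1][0])
  ]
  decoder_inputs = [
      array_ops.placeholder(dtypes.int32, [None])
      for _ in xrange(buckets[-1][1])
  ]
  targets = [
      array_ops.placeholder(dtypes.int32, [None])
      for _ in xrange(buckets[-1][1])
  ]
  weights = [
      array_ops.placeholder(dtypes.float32, [None])
      for _ in xrange(buckets[-1][1])
  ]

  def seq2seq_f(enc_inp, dec_inp):
    cell = rnn_cell_impl.GRUCell(32)
    return embedding_rnn_seq2seq(
        enc_inp,
        dec_inp,
        cell,
        num_encoder_symbols=100,
        num_decoder_symbols=100,
        embedding_size=32)

  return model_with_buckets(
      encoder_inputs, decoder_inputs, targets, weights, buckets, seq2seq_f)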