# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for creating sequence-to-sequence models in TensorFlow.

Sequence-to-sequence recurrent neural networks can learn complex functions
that map input sequences to output sequences. These models yield very good
results on a number of tasks, such as speech recognition, parsing, machine
translation, or even constructing automated replies to emails.

Before using this module, it is recommended to read the TensorFlow tutorial
on sequence-to-sequence models. It explains the basic concepts of this module
and shows an end-to-end example of how to build a translation model.
  https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html

Here is an overview of functions available in this module. They all use
a very similar interface, so after reading the above tutorial and using
one of them, others should be easy to substitute.

* Full sequence-to-sequence models.
  - basic_rnn_seq2seq: The most basic RNN-RNN model.
  - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights.
  - embedding_rnn_seq2seq: The basic model with input embedding.
  - embedding_tied_rnn_seq2seq: The tied model with input embedding.
  - embedding_attention_seq2seq: Advanced model with input embedding and
      the neural attention mechanism; recommended for complex tasks.

* Multi-task sequence-to-sequence models.
  - one2many_rnn_seq2seq: The embedding model with multiple decoders.

* Decoders (when you write your own encoder, you can use these to decode;
    e.g., if you want to write a model that generates captions for images).
  - rnn_decoder: The basic decoder based on a pure RNN.
  - attention_decoder: A decoder that uses the attention mechanism.

* Losses.
  - sequence_loss: Loss for a sequence model returning average log-perplexity.
  - sequence_loss_by_example: As above, but not averaging over all examples.

* model_with_buckets: A convenience function to create models with bucketing
    (see the tutorial above for an explanation of why and how to use it).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

# TODO(ebrevdo): Remove once _linear is fully deprecated.
Linear = core_rnn_cell._Linear  # pylint: disable=protected-access,invalid-name


def _extract_argmax_and_embed(embedding,
                              output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: Embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and have B added to it.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """

  def loop_function(prev, _):
    if output_projection is not None:
      prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = math_ops.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = array_ops.stop_gradient(emb_prev)
    return emb_prev

  return loop_function


def rnn_decoder(decoder_inputs,
                initial_state,
                cell,
                loop_function=None,
                scope=None):
  """RNN decoder for the sequence-to-sequence model.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor with shape [batch_size x cell.state_size].
    cell: rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-st input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    scope: VariableScope for the created subgraph; defaults to "rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing generated outputs.
      state: The state of each cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
        (Note that in some cases, like basic RNN cell or GRU cell, outputs and
         states can be the same. They are different for LSTM cells though.)
  """
  with variable_scope.variable_scope(scope or "rnn_decoder"):
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      output, state = cell(inp, state)
      outputs.append(output)
      if loop_function is not None:
        prev = output
  return outputs, state


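# Illustrative usage sketch for rnn_decoder (not part of the library; kept as
# a comment so nothing runs at import time). The GRU size, batch size, and
# sequence length below are arbitrary assumptions for the example.
#
#   cell = rnn_cell_impl.GRUCell(16)
#   batch_size, input_size, num_steps = 32, 8, 5
#   decoder_inputs = [
#       array_ops.placeholder(dtypes.float32, [batch_size, input_size])
#       for _ in xrange(num_steps)]
#   initial_state = cell.zero_state(batch_size, dtypes.float32)
#   # outputs is a list of num_steps Tensors of shape [batch_size, 16];
#   # state is the final GRU state of shape [batch_size, 16].
#   outputs, state = rnn_decoder(decoder_inputs, initial_state, cell)

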
def basic_rnn_seq2seq(encoder_inputs,
                      decoder_inputs,
                      cell,
                      dtype=dtypes.float32,
                      scope=None):
  """Basic RNN sequence-to-sequence model.

  This model first runs an RNN to encode encoder_inputs into a state vector,
  then runs decoder, initialized with the last encoder state, on decoder_inputs.
  Encoder and decoder use the same RNN cell type, but don't share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell in the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
    enc_cell = copy.deepcopy(cell)
    _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
    return rnn_decoder(decoder_inputs, enc_state, cell)


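# Illustrative usage sketch for basic_rnn_seq2seq (not part of the library).
# All dimensions below are arbitrary assumptions for the example.
#
#   cell = rnn_cell_impl.GRUCell(24)
#   batch_size, input_size = 32, 10
#   encoder_inputs = [
#       array_ops.placeholder(dtypes.float32, [batch_size, input_size])
#       for _ in xrange(6)]
#   decoder_inputs = [
#       array_ops.placeholder(dtypes.float32, [batch_size, input_size])
#       for _ in xrange(7)]
#   # The encoder uses a deep copy of `cell`, so no weights are shared with
#   # the decoder.
#   outputs, state = basic_rnn_seq2seq(encoder_inputs, decoder_inputs, cell)

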
def tied_rnn_seq2seq(encoder_inputs,
                     decoder_inputs,
                     cell,
                     loop_function=None,
                     dtype=dtypes.float32,
                     scope=None):
  """RNN sequence-to-sequence model with tied encoder and decoder parameters.

  This model first runs an RNN to encode encoder_inputs into a state vector, and
  then runs decoder, initialized with the last encoder state, on decoder_inputs.
  Encoder and decoder use the same RNN cell and share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol), see rnn_decoder for details.
    dtype: The dtype of the initial state of the rnn cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
    scope = scope or "tied_rnn_seq2seq"
    _, enc_state = rnn.static_rnn(
        cell, encoder_inputs, dtype=dtype, scope=scope)
    variable_scope.get_variable_scope().reuse_variables()
    return rnn_decoder(
        decoder_inputs,
        enc_state,
        cell,
        loop_function=loop_function,
        scope=scope)


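# Illustrative usage sketch for tied_rnn_seq2seq (not part of the library).
# It is called exactly like basic_rnn_seq2seq above; the only difference is
# that the encoder and decoder reuse the same cell variables instead of each
# getting its own copy. `encoder_inputs` and `decoder_inputs` are assumed to
# be lists of 2D float Tensors as in the basic_rnn_seq2seq sketch.
#
#   cell = rnn_cell_impl.GRUCell(24)
#   outputs, state = tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell)

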
def embedding_rnn_decoder(decoder_inputs,
                          initial_state,
                          cell,
                          num_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          update_embedding_for_previous=True,
                          scope=None):
  """RNN decoder with embedding and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each fed
      previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)).
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope:
    if output_projection is not None:
      dtype = scope.dtype
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = (embedding_ops.embedding_lookup(embedding, i)
               for i in decoder_inputs)
    return rnn_decoder(
        emb_inp, initial_state, cell, loop_function=loop_function)


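# Illustrative usage sketch for embedding_rnn_decoder (not part of the
# library). All sizes are arbitrary assumptions; `initial_state` would
# normally come from an encoder rather than a zero state.
#
#   num_symbols, embedding_size, batch_size = 1000, 64, 32
#   cell = core_rnn_cell.OutputProjectionWrapper(
#       rnn_cell_impl.GRUCell(64), num_symbols)
#   decoder_inputs = [
#       array_ops.placeholder(dtypes.int32, [batch_size]) for _ in xrange(10)]
#   initial_state = cell.zero_state(batch_size, dtypes.float32)
#   # With feed_previous=True only decoder_inputs[0] (the "GO" symbol) is
#   # read; later inputs are embeddings of the argmax of previous outputs,
#   # i.e. greedy decoding.
#   outputs, state = embedding_rnn_decoder(
#       decoder_inputs, initial_state, cell, num_symbols, embedding_size,
#       feed_previous=True)

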
def embedding_rnn_seq2seq(encoder_inputs,
                          decoder_inputs,
                          cell,
                          num_encoder_symbols,
                          num_decoder_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          dtype=None,
                          scope=None):
  """Embedding RNN sequence-to-sequence model.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope:
    if dtype is not None:
      scope.set_dtype(dtype)
    else:
      dtype = scope.dtype

    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)

    if isinstance(feed_previous, bool):
      return embedding_rnn_decoder(
          decoder_inputs,
          encoder_state,
          cell,
          num_decoder_symbols,
          embedding_size,
          output_projection=output_projection,
          feed_previous=feed_previous)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_rnn_decoder(
            decoder_inputs,
            encoder_state,
            cell,
            num_decoder_symbols,
            embedding_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


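# Illustrative usage sketch for embedding_rnn_seq2seq (not part of the
# library). Vocabulary sizes and dimensions are arbitrary assumptions.
#
#   cell = rnn_cell_impl.GRUCell(128)
#   batch_size = 32
#   encoder_inputs = [
#       array_ops.placeholder(dtypes.int32, [batch_size]) for _ in xrange(8)]
#   decoder_inputs = [
#       array_ops.placeholder(dtypes.int32, [batch_size]) for _ in xrange(9)]
#   outputs, state = embedding_rnn_seq2seq(
#       encoder_inputs, decoder_inputs, cell,
#       num_encoder_symbols=10000, num_decoder_symbols=8000,
#       embedding_size=128, feed_previous=False)
#   # Each element of outputs has shape [batch_size, 8000] because no
#   # output_projection was given, so the decoder cell is wrapped in an
#   # OutputProjectionWrapper over num_decoder_symbols.

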
def embedding_tied_rnn_seq2seq(encoder_inputs,
                               decoder_inputs,
                               cell,
                               num_symbols,
                               embedding_size,
                               num_decoder_symbols=None,
                               output_projection=None,
                               feed_previous=False,
                               dtype=None,
                               scope=None):
  """Embedding RNN sequence-to-sequence model with tied (shared) parameters.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_symbols x input_size]). Then it runs an RNN to encode embedded
  encoder_inputs into a state vector. Next, it embeds decoder_inputs using
  the same embedding. Then it runs RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs. The decoder output is over symbols
  from 0 to num_decoder_symbols - 1 if num_decoder_symbols is provided;
  otherwise it is over symbols from 0 to num_symbols - 1.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_symbols: Integer; number of symbols for both encoder and decoder.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_decoder_symbols: Integer; number of output symbols for decoder. If
      provided, the decoder output is over symbols 0 to num_decoder_symbols - 1.
      Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that
      this assumes that the vocabulary is set up such that the first
      num_decoder_symbols of num_symbols are part of decoding.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype to use for the initial RNN states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_symbols] containing the generated outputs,
        where output_symbols = num_decoder_symbols if num_decoder_symbols is
        not None, and output_symbols = num_symbols otherwise.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(
      scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    if output_projection is not None:
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable(
        "embedding", [num_symbols, embedding_size], dtype=dtype)

    emb_encoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in encoder_inputs
    ]
    emb_decoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in decoder_inputs
    ]

    output_symbols = num_symbols
    if num_decoder_symbols is not None:
      output_symbols = num_decoder_symbols
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols)

    if isinstance(feed_previous, bool):
      loop_function = _extract_argmax_and_embed(embedding, output_projection,
                                                True) if feed_previous else None
      return tied_rnn_seq2seq(
          emb_encoder_inputs,
          emb_decoder_inputs,
          cell,
          loop_function=loop_function,
          dtype=dtype)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      loop_function = _extract_argmax_and_embed(
          embedding, output_projection, False) if feed_previous_bool else None
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = tied_rnn_seq2seq(
            emb_encoder_inputs,
            emb_decoder_inputs,
            cell,
            loop_function=loop_function,
            dtype=dtype)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    # Calculate zero-state to know its structure.
    static_batch_size = encoder_inputs[0].get_shape()[0]
    for inp in encoder_inputs[1:]:
      static_batch_size.merge_with(inp.get_shape()[0])
    batch_size = static_batch_size.value
    if batch_size is None:
      batch_size = array_ops.shape(encoder_inputs[0])[0]
    zero_state = cell.zero_state(batch_size, dtype)
    if nest.is_sequence(zero_state):
      state = nest.pack_sequence_as(
          structure=zero_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


def attention_decoder(decoder_inputs,
                      initial_state,
                      attention_states,
                      cell,
                      output_size=None,
                      num_heads=1,
                      loop_function=None,
                      dtype=None,
                      scope=None,
                      initial_state_attention=False):
  """RNN decoder with attention for the sequence-to-sequence model.

  In this context "attention" means that, during decoding, the RNN can look up
  information in the additional tensor attention_states, and it does this by
  focusing on a few entries from the tensor. This model has proven to yield
  especially good results in a number of sequence-to-sequence tasks. This
  implementation is based on http://arxiv.org/abs/1412.7449 (see below for
  details). It is recommended for complex sequence-to-sequence tasks.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    output_size: Size of the output vectors; if None, we use cell.output_size.
    num_heads: Number of attention heads that read from attention_states.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    dtype: The dtype to use for the RNN initial state (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors of
        shape [batch_size x output_size]. These represent the generated outputs.
        Output i is computed from input i (which is either the i-th element
        of decoder_inputs or loop_function(output {i-1}, i)) as follows.
        First, we run the cell on a combination of the input and previous
        attention masks:
          cell_output, new_state = cell(linear(input, prev_attn), prev_state).
        Then, we calculate new attention masks:
          new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
        and then we calculate the output:
          output = linear(cell_output, new_attn).
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: when num_heads is not positive, there are no inputs, shapes
      of attention_states are not set, or input size cannot be inferred
      from the input.
  """
  if not decoder_inputs:
    raise ValueError("Must provide at least 1 input to attention decoder.")
  if num_heads < 1:
    raise ValueError("With less than 1 heads, use a non-attention decoder.")
  if attention_states.get_shape()[2].value is None:
    raise ValueError("Shape[2] of attention_states must be known: %s" %
                     attention_states.get_shape())
  if output_size is None:
    output_size = cell.output_size

  with variable_scope.variable_scope(
      scope or "attention_decoder", dtype=dtype) as scope:
    dtype = scope.dtype

    batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
    attn_length = attention_states.get_shape()[1].value
    if attn_length is None:
      attn_length = array_ops.shape(attention_states)[1]
    attn_size = attention_states.get_shape()[2].value

    # To calculate W1 * h_t we use a 1-by-1 convolution; we need to reshape
    # the attention states before.
    hidden = array_ops.reshape(attention_states,
                               [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable(
          "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size],
          dtype=dtype)
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable(
              "AttnV_%d" % a, [attention_vec_size], dtype=dtype))

    state = initial_state

    def attention(query):
      """Put attention masks on hidden using hidden_features and query."""
      ds = []  # Results of attention reads will be stored here.
      if nest.is_sequence(query):  # If the query is a tuple, flatten it.
        query_list = nest.flatten(query)
        for q in query_list:  # Check that ndims == 2 if specified.
          ndims = q.get_shape().ndims
          if ndims:
            assert ndims == 2
        query = array_ops.concat(query_list, 1)
      for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          y = Linear(query, attention_vec_size, True)(query)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          y = math_ops.cast(y, dtype)
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                                  [2, 3])
          a = nn_ops.softmax(math_ops.cast(s, dtype=dtypes.float32))
          # Now calculate the attention-weighted vector d.
          a = math_ops.cast(a, dtype)
          d = math_ops.reduce_sum(
              array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

    outputs = []
    prev = None
    batch_attn_size = array_ops.stack([batch_size, attn_size])
    attns = [
        array_ops.zeros(
            batch_attn_size, dtype=dtype) for _ in xrange(num_heads)
    ]
    for a in attns:  # Ensure the second shape of attention vectors is set.
      a.set_shape([None, attn_size])
    if initial_state_attention:
      attns = attention(initial_state)
    for i, inp in enumerate(decoder_inputs):
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      # If loop_function is set, we use it instead of decoder_inputs.
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      # Merge input and previous attentions into one vector of the right size.
      input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)

      inputs = [inp] + attns
      inputs = [math_ops.cast(e, dtype) for e in inputs]
      x = Linear(inputs, input_size, True)(inputs)
      # Run the RNN.
      cell_output, state = cell(x, state)
      # Run the attention mechanism.
      if i == 0 and initial_state_attention:
        with variable_scope.variable_scope(
            variable_scope.get_variable_scope(), reuse=True):
          attns = attention(state)
      else:
        attns = attention(state)

      with variable_scope.variable_scope("AttnOutputProjection"):
        cell_output = math_ops.cast(cell_output, dtype)
        inputs = [cell_output] + attns
        output = Linear(inputs, output_size, True)(inputs)
      if loop_function is not None:
        prev = output
      outputs.append(output)

  return outputs, state


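# Illustrative usage sketch for attention_decoder (not part of the library).
# attention_states would normally be the stacked encoder outputs; all sizes
# below are arbitrary assumptions for the example.
#
#   cell = rnn_cell_impl.GRUCell(32)
#   batch_size, attn_length, attn_size = 16, 10, 32
#   attention_states = array_ops.placeholder(
#       dtypes.float32, [batch_size, attn_length, attn_size])
#   decoder_inputs = [
#       array_ops.placeholder(dtypes.float32, [batch_size, 32])
#       for _ in xrange(7)]
#   initial_state = cell.zero_state(batch_size, dtypes.float32)
#   outputs, state = attention_decoder(
#       decoder_inputs, initial_state, attention_states, cell, num_heads=1)

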
def embedding_attention_decoder(decoder_inputs,
                                initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                embedding_size,
                                num_heads=1,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """RNN decoder with embedding and attention and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_size: Size of the output vectors; if None, use cell.output_size.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has shape
      [num_symbols]; if provided and feed_previous=True, each fed previous
      output will first be multiplied by W and have B added to it.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)).
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    dtype: The dtype to use for the RNN initial states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  if output_size is None:
    output_size = cell.output_size
  if output_projection is not None:
    proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
    proj_biases.get_shape().assert_is_compatible_with([num_symbols])

  with variable_scope.variable_scope(
      scope or "embedding_attention_decoder", dtype=dtype) as scope:

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = [
        embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs
    ]
    return attention_decoder(
        emb_inp,
        initial_state,
        attention_states,
        cell,
        output_size=output_size,
        num_heads=num_heads,
        loop_function=loop_function,
        initial_state_attention=initial_state_attention)


def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """Embedding sequence-to-sequence model with attention.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. It keeps the outputs of this
  RNN at every step to use for attention later. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs attention decoder, initialized with the last
  encoder state, on embedded decoder_inputs and attending to encoder outputs.

  Warning: when output_projection is None, the size of the attention vectors
  and variables will be made proportional to num_decoder_symbols, which can be
  large.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial RNN state (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_seq2seq".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x num_decoder_symbols] containing the generated
        outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(
      scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype
    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    encoder_outputs, encoder_state = rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)

    # First calculate a concatenation of encoder outputs to put attention on.
    top_states = [
        array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs
    ]
    attention_states = array_ops.concat(top_states, 1)

    # Decoder.
    output_size = None
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
      output_size = num_decoder_symbols

    if isinstance(feed_previous, bool):
      return embedding_attention_decoder(
          decoder_inputs,
          encoder_state,
          attention_states,
          cell,
          num_decoder_symbols,
          embedding_size,
          num_heads=num_heads,
          output_size=output_size,
          output_projection=output_projection,
          feed_previous=feed_previous,
          initial_state_attention=initial_state_attention)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_attention_decoder(
            decoder_inputs,
            encoder_state,
            attention_states,
            cell,
            num_decoder_symbols,
            embedding_size,
            num_heads=num_heads,
            output_size=output_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False,
            initial_state_attention=initial_state_attention)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


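# Illustrative usage sketch for embedding_attention_seq2seq (not part of the
# library). Vocabulary sizes and dimensions are arbitrary assumptions for the
# example.
#
#   cell = rnn_cell_impl.GRUCell(256)
#   batch_size = 64
#   encoder_inputs = [
#       array_ops.placeholder(dtypes.int32, [batch_size]) for _ in xrange(20)]
#   decoder_inputs = [
#       array_ops.placeholder(dtypes.int32, [batch_size]) for _ in xrange(21)]
#   outputs, state = embedding_attention_seq2seq(
#       encoder_inputs, decoder_inputs, cell,
#       num_encoder_symbols=40000, num_decoder_symbols=40000,
#       embedding_size=256, feed_previous=False)

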
def one2many_rnn_seq2seq(encoder_inputs,
                         decoder_inputs_dict,
                         enc_cell,
                         dec_cells_dict,
                         num_encoder_symbols,
                         num_decoder_symbols_dict,
                         embedding_size,
                         feed_previous=False,
                         dtype=None,
                         scope=None):
  """One-to-many RNN sequence-to-sequence model (multi-task).

  This is a multi-task sequence-to-sequence model with one encoder and multiple
  decoders. A reference for multi-task sequence-to-sequence learning can be
  found here: http://arxiv.org/abs/1511.06114

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs_dict: A dictionary mapping decoder name (string) to
      the corresponding decoder_inputs; each decoder_inputs is a list of 1D
      Tensors of shape [batch_size]; num_decoders is defined as
      len(decoder_inputs_dict).
    enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and
      size.
    dec_cells_dict: A dictionary mapping decoder name (string) to an
      instance of tf.nn.rnn_cell.RNNCell.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
      integer specifying number of symbols for the corresponding decoder;
      len(num_decoder_symbols_dict) must be equal to num_decoders.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of
      decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "one2many_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs_dict, state_dict), where:
      outputs_dict: A mapping from decoder name (string) to a list of the same
        length as decoder_inputs_dict[name]; each element in the list is a 2D
        Tensor with shape [batch_size x num_decoder_symbol_list[name]]
        containing the generated outputs.
      state_dict: A mapping from decoder name (string) to the final state of the
        corresponding decoder RNN; it is a 2D Tensor of shape
        [batch_size x cell.state_size].

  Raises:
    TypeError: if enc_cell or any of the cells in dec_cells_dict are not
      instances of RNNCell.
    ValueError: if the keys of dec_cells_dict differ from the keys of
      decoder_inputs_dict.
  """
  outputs_dict = {}
  state_dict = {}

  if not isinstance(enc_cell, rnn_cell_impl.RNNCell):
    raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
  if set(dec_cells_dict) != set(decoder_inputs_dict):
    raise ValueError("keys of dec_cells_dict != keys of decoder_inputs_dict")
  for dec_cell in dec_cells_dict.values():
    if not isinstance(dec_cell, rnn_cell_impl.RNNCell):
      raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))

  with variable_scope.variable_scope(
      scope or "one2many_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    # Encoder.
    enc_cell = core_rnn_cell.EmbeddingWrapper(
        enc_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    for name, decoder_inputs in decoder_inputs_dict.items():
      num_decoder_symbols = num_decoder_symbols_dict[name]
      dec_cell = dec_cells_dict[name]

      with variable_scope.variable_scope("one2many_decoder_" + str(
          name)) as scope:
        dec_cell = core_rnn_cell.OutputProjectionWrapper(
            dec_cell, num_decoder_symbols)
        if isinstance(feed_previous, bool):
          outputs, state = embedding_rnn_decoder(
              decoder_inputs,
              encoder_state,
              dec_cell,
              num_decoder_symbols,
              embedding_size,
              feed_previous=feed_previous)
        else:
          # If feed_previous is a Tensor, we construct 2 graphs and use cond.
          def filled_embedding_rnn_decoder(feed_previous):
            """The current decoder with a fixed feed_previous parameter."""
            # pylint: disable=cell-var-from-loop
            reuse = None if feed_previous else True
            vs = variable_scope.get_variable_scope()
            with variable_scope.variable_scope(vs, reuse=reuse):
              outputs, state = embedding_rnn_decoder(
                  decoder_inputs,
                  encoder_state,
                  dec_cell,
                  num_decoder_symbols,
                  embedding_size,
                  feed_previous=feed_previous)
            # pylint: enable=cell-var-from-loop
            state_list = [state]
            if nest.is_sequence(state):
              state_list = nest.flatten(state)
            return outputs + state_list

          outputs_and_state = control_flow_ops.cond(
              feed_previous, lambda: filled_embedding_rnn_decoder(True),
              lambda: filled_embedding_rnn_decoder(False))
          # Outputs length is the same as for decoder inputs.
          outputs_len = len(decoder_inputs)
          outputs = outputs_and_state[:outputs_len]
          state_list = outputs_and_state[outputs_len:]
          state = state_list[0]
          if nest.is_sequence(encoder_state):
            state = nest.pack_sequence_as(
                structure=encoder_state, flat_sequence=state_list)
      outputs_dict[name] = outputs
      state_dict[name] = state

  return outputs_dict, state_dict


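# Illustrative usage sketch for one2many_rnn_seq2seq (not part of the
# library). Decoder names, vocabulary sizes, and dimensions are arbitrary
# assumptions for the example.
#
#   batch_size = 32
#   encoder_inputs = [
#       array_ops.placeholder(dtypes.int32, [batch_size]) for _ in xrange(8)]
#   decoder_inputs_dict = {
#       "task_a": [array_ops.placeholder(dtypes.int32, [batch_size])
#                  for _ in xrange(9)],
#       "task_b": [array_ops.placeholder(dtypes.int32, [batch_size])
#                  for _ in xrange(9)],
#   }
#   dec_cells_dict = {name: rnn_cell_impl.GRUCell(64)
#                     for name in decoder_inputs_dict}
#   outputs_dict, state_dict = one2many_rnn_seq2seq(
#       encoder_inputs, decoder_inputs_dict, rnn_cell_impl.GRUCell(64),
#       dec_cells_dict, num_encoder_symbols=5000,
#       num_decoder_symbols_dict={"task_a": 3000, "task_b": 4000},
#       embedding_size=64)

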
def sequence_loss_by_example(logits,
                             targets,
                             weights,
                             average_across_timesteps=True,
                             softmax_loss_function=None,
                             name=None):
  """Weighted cross-entropy loss for a sequence of logits (per example).

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    name: Optional name for this operation, default: "sequence_loss_by_example".

  Returns:
    1D batch-sized float Tensor: The log-perplexity for each sequence.

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  if len(targets) != len(logits) or len(weights) != len(logits):
    raise ValueError("Lengths of logits, weights, and targets must be the same "
                     "%d, %d, %d." % (len(logits), len(weights), len(targets)))
  with ops.name_scope(name, "sequence_loss_by_example",
                      logits + targets + weights):
    log_perp_list = []
    for logit, target, weight in zip(logits, targets, weights):
      if softmax_loss_function is None:
        # TODO(irving,ebrevdo): This reshape is needed because
        # sequence_loss_by_example is called with scalars sometimes, which
        # violates our general scalar strictness policy.
        target = array_ops.reshape(target, [-1])
        crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
            labels=target, logits=logit)
      else:
        crossent = softmax_loss_function(labels=target, logits=logit)
      log_perp_list.append(crossent * weight)
    log_perps = math_ops.add_n(log_perp_list)
    if average_across_timesteps:
      total_size = math_ops.add_n(weights)
      total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
      log_perps /= total_size
  return log_perps


def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):
  """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    average_across_batch: If set, divide the returned cost by the batch size.
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A scalar float Tensor: The average log-perplexity per symbol (weighted).

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  with ops.name_scope(name, "sequence_loss", logits + targets + weights):
    cost = math_ops.reduce_sum(
        sequence_loss_by_example(
            logits,
            targets,
            weights,
            average_across_timesteps=average_across_timesteps,
            softmax_loss_function=softmax_loss_function))
    if average_across_batch:
      batch_size = array_ops.shape(targets[0])[0]
      return cost / math_ops.cast(batch_size, cost.dtype)
    else:
      return cost


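# Illustrative usage sketch for the loss functions (not part of the library).
# The weights are typically 1.0 for real target symbols and 0.0 for padding;
# all shapes below are arbitrary assumptions.
#
#   num_steps, batch_size, num_symbols = 5, 32, 1000
#   logits = [array_ops.placeholder(dtypes.float32, [batch_size, num_symbols])
#             for _ in xrange(num_steps)]
#   targets = [array_ops.placeholder(dtypes.int32, [batch_size])
#              for _ in xrange(num_steps)]
#   weights = [array_ops.placeholder(dtypes.float32, [batch_size])
#              for _ in xrange(num_steps)]
#   # Scalar loss averaged over time-steps and the batch.
#   loss = sequence_loss(logits, targets, weights)
#   # Per-example loss: a [batch_size] Tensor, averaged over time-steps only.
#   loss_per_example = sequence_loss_by_example(logits, targets, weights)

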
def model_with_buckets(encoder_inputs,
                       decoder_inputs,
                       targets,
                       weights,
                       buckets,
                       seq2seq,
                       softmax_loss_function=None,
                       per_example_loss=False,
                       name=None):
  """Create a sequence-to-sequence model with support for bucketing.

  The seq2seq argument is a function that defines a sequence-to-sequence model,
  e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
      x, y, rnn_cell.GRUCell(24))

  Args:
    encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
    decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
    targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
    weights: List of 1D batch-sized float-Tensors to weight the targets.
    buckets: A list of pairs of (input size, output size) for each bucket.
    seq2seq: A sequence-to-sequence model function; it takes 2 inputs that
      agree with encoder_inputs and decoder_inputs, and returns a pair
      consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    per_example_loss: Boolean. If set, the returned loss will be a batch-sized
      tensor of losses for each sequence in the batch. If unset, it will be
      a scalar with the averaged loss from all examples.
    name: Optional name for this operation, defaults to "model_with_buckets".

  Returns:
    A tuple of the form (outputs, losses), where:
      outputs: The outputs for each bucket. Its j'th element consists of a list
        of 2D Tensors. The shape of output tensors can be either
        [batch_size x output_size] or [batch_size x num_decoder_symbols]
        depending on the seq2seq model used.
      losses: List of scalar Tensors, representing losses for each bucket, or,
        if per_example_loss is set, a list of 1D batch-sized float Tensors.

  Raises:
    ValueError: If length of encoder_inputs, targets, or weights is smaller
      than the largest (last) bucket.
  """
  if len(encoder_inputs) < buckets[-1][0]:
    raise ValueError("Length of encoder_inputs (%d) must be at least that of "
                     "last bucket (%d)." % (len(encoder_inputs),
                                            buckets[-1][0]))
  if len(targets) < buckets[-1][1]:
    raise ValueError("Length of targets (%d) must be at least that of last "
                     "bucket (%d)." % (len(targets), buckets[-1][1]))
  if len(weights) < buckets[-1][1]:
    raise ValueError("Length of weights (%d) must be at least that of last "
                     "bucket (%d)." % (len(weights), buckets[-1][1]))

  all_inputs = encoder_inputs + decoder_inputs + targets + weights
  losses = []
  outputs = []
  with ops.name_scope(name, "model_with_buckets", all_inputs):
    for j, bucket in enumerate(buckets):
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
        bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
                                    decoder_inputs[:bucket[1]])
        outputs.append(bucket_outputs)
        if per_example_loss:
          losses.append(
              sequence_loss_by_example(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))
        else:
          losses.append(
              sequence_loss(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))

  return outputs, losses
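

# Illustrative usage sketch for model_with_buckets (not part of the library).
# Buckets, dimensions, and the placeholder setup are arbitrary assumptions;
# the input lists must be at least as long as the largest bucket.
#
#   buckets = [(5, 10), (10, 15)]
#   cell = rnn_cell_impl.GRUCell(32)
#   batch_size = 32
#   encoder_inputs = [array_ops.placeholder(dtypes.int32, [batch_size])
#                     for _ in xrange(buckets[-1][0])]
#   decoder_inputs = [array_ops.placeholder(dtypes.int32, [batch_size])
#                     for _ in xrange(buckets[-1][1])]
#   targets = [array_ops.placeholder(dtypes.int32, [batch_size])
#              for _ in xrange(buckets[-1][1])]
#   weights = [array_ops.placeholder(dtypes.float32, [batch_size])
#              for _ in xrange(buckets[-1][1])]
#   seq2seq_f = lambda x, y: embedding_rnn_seq2seq(
#       x, y, cell, num_encoder_symbols=1000, num_decoder_symbols=1000,
#       embedding_size=32)
#   # One graph per bucket is built; variables are reused across buckets.
#   outputs, losses = model_with_buckets(
#       encoder_inputs, decoder_inputs, targets, weights, buckets, seq2seq_f)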