1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library for creating sequence-to-sequence models in TensorFlow. 16 17Sequence-to-sequence recurrent neural networks can learn complex functions 18that map input sequences to output sequences. These models yield very good 19results on a number of tasks, such as speech recognition, parsing, machine 20translation, or even constructing automated replies to emails. 21 22Before using this module, it is recommended to read the TensorFlow tutorial 23on sequence-to-sequence models. It explains the basic concepts of this module 24and shows an end-to-end example of how to build a translation model. 25 https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html 26 27Here is an overview of functions available in this module. They all use 28a very similar interface, so after reading the above tutorial and using 29one of them, others should be easy to substitute. 30 31* Full sequence-to-sequence models. 32 - basic_rnn_seq2seq: The most basic RNN-RNN model. 33 - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights. 34 - embedding_rnn_seq2seq: The basic model with input embedding. 35 - embedding_tied_rnn_seq2seq: The tied model with input embedding. 36 - embedding_attention_seq2seq: Advanced model with input embedding and 37 the neural attention mechanism; recommended for complex tasks. 38 39* Multi-task sequence-to-sequence models. 40 - one2many_rnn_seq2seq: The embedding model with multiple decoders. 41 42* Decoders (when you write your own encoder, you can use these to decode; 43 e.g., if you want to write a model that generates captions for images). 44 - rnn_decoder: The basic decoder based on a pure RNN. 45 - attention_decoder: A decoder that uses the attention mechanism. 46 47* Losses. 48 - sequence_loss: Loss for a sequence model returning average log-perplexity. 49 - sequence_loss_by_example: As above, but not averaging over all examples. 50 51* model_with_buckets: A convenience function to create models with bucketing 52 (see the tutorial above for an explanation of why and how to use it). 53""" 54 55from __future__ import absolute_import 56from __future__ import division 57from __future__ import print_function 58 59import copy 60 61# We disable pylint because we need python3 compatibility. 62from six.moves import xrange # pylint: disable=redefined-builtin 63from six.moves import zip # pylint: disable=redefined-builtin 64 65from tensorflow.contrib.rnn.python.ops import core_rnn_cell 66from tensorflow.python.framework import dtypes 67from tensorflow.python.framework import ops 68from tensorflow.python.ops import array_ops 69from tensorflow.python.ops import control_flow_ops 70from tensorflow.python.ops import embedding_ops 71from tensorflow.python.ops import math_ops 72from tensorflow.python.ops import nn_ops 73from tensorflow.python.ops import rnn 74from tensorflow.python.ops import rnn_cell_impl 75from tensorflow.python.ops import variable_scope 76from tensorflow.python.util import nest 77 78# TODO(ebrevdo): Remove once _linear is fully deprecated. 79Linear = core_rnn_cell._Linear # pylint: disable=protected-access,invalid-name 80 81 82def _extract_argmax_and_embed(embedding, 83 output_projection=None, 84 update_embedding=True): 85 """Get a loop_function that extracts the previous symbol and embeds it. 86 87 Args: 88 embedding: embedding tensor for symbols. 89 output_projection: None or a pair (W, B). If provided, each fed previous 90 output will first be multiplied by W and added B. 91 update_embedding: Boolean; if False, the gradients will not propagate 92 through the embeddings. 93 94 Returns: 95 A loop function. 96 """ 97 98 def loop_function(prev, _): 99 if output_projection is not None: 100 prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1]) 101 prev_symbol = math_ops.argmax(prev, 1) 102 # Note that gradients will not propagate through the second parameter of 103 # embedding_lookup. 104 emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 105 if not update_embedding: 106 emb_prev = array_ops.stop_gradient(emb_prev) 107 return emb_prev 108 109 return loop_function 110 111 112def rnn_decoder(decoder_inputs, 113 initial_state, 114 cell, 115 loop_function=None, 116 scope=None): 117 """RNN decoder for the sequence-to-sequence model. 118 119 Args: 120 decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 121 initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 122 cell: rnn_cell.RNNCell defining the cell function and size. 123 loop_function: If not None, this function will be applied to the i-th output 124 in order to generate the i+1-st input, and decoder_inputs will be ignored, 125 except for the first element ("GO" symbol). This can be used for decoding, 126 but also for training to emulate http://arxiv.org/abs/1506.03099. 127 Signature -- loop_function(prev, i) = next 128 * prev is a 2D Tensor of shape [batch_size x output_size], 129 * i is an integer, the step number (when advanced control is needed), 130 * next is a 2D Tensor of shape [batch_size x input_size]. 131 scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 132 133 Returns: 134 A tuple of the form (outputs, state), where: 135 outputs: A list of the same length as decoder_inputs of 2D Tensors with 136 shape [batch_size x output_size] containing generated outputs. 137 state: The state of each cell at the final time-step. 138 It is a 2D Tensor of shape [batch_size x cell.state_size]. 139 (Note that in some cases, like basic RNN cell or GRU cell, outputs and 140 states can be the same. They are different for LSTM cells though.) 141 """ 142 with variable_scope.variable_scope(scope or "rnn_decoder"): 143 state = initial_state 144 outputs = [] 145 prev = None 146 for i, inp in enumerate(decoder_inputs): 147 if loop_function is not None and prev is not None: 148 with variable_scope.variable_scope("loop_function", reuse=True): 149 inp = loop_function(prev, i) 150 if i > 0: 151 variable_scope.get_variable_scope().reuse_variables() 152 output, state = cell(inp, state) 153 outputs.append(output) 154 if loop_function is not None: 155 prev = output 156 return outputs, state 157 158 159def basic_rnn_seq2seq(encoder_inputs, 160 decoder_inputs, 161 cell, 162 dtype=dtypes.float32, 163 scope=None): 164 """Basic RNN sequence-to-sequence model. 165 166 This model first runs an RNN to encode encoder_inputs into a state vector, 167 then runs decoder, initialized with the last encoder state, on decoder_inputs. 168 Encoder and decoder use the same RNN cell type, but don't share parameters. 169 170 Args: 171 encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 172 decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 173 cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. 174 dtype: The dtype of the initial state of the RNN cell (default: tf.float32). 175 scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". 176 177 Returns: 178 A tuple of the form (outputs, state), where: 179 outputs: A list of the same length as decoder_inputs of 2D Tensors with 180 shape [batch_size x output_size] containing the generated outputs. 181 state: The state of each decoder cell in the final time-step. 182 It is a 2D Tensor of shape [batch_size x cell.state_size]. 183 """ 184 with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): 185 enc_cell = copy.deepcopy(cell) 186 _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) 187 return rnn_decoder(decoder_inputs, enc_state, cell) 188 189 190def tied_rnn_seq2seq(encoder_inputs, 191 decoder_inputs, 192 cell, 193 loop_function=None, 194 dtype=dtypes.float32, 195 scope=None): 196 """RNN sequence-to-sequence model with tied encoder and decoder parameters. 197 198 This model first runs an RNN to encode encoder_inputs into a state vector, and 199 then runs decoder, initialized with the last encoder state, on decoder_inputs. 200 Encoder and decoder use the same RNN cell and share parameters. 201 202 Args: 203 encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 204 decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 205 cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. 206 loop_function: If not None, this function will be applied to i-th output 207 in order to generate i+1-th input, and decoder_inputs will be ignored, 208 except for the first element ("GO" symbol), see rnn_decoder for details. 209 dtype: The dtype of the initial state of the rnn cell (default: tf.float32). 210 scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". 211 212 Returns: 213 A tuple of the form (outputs, state), where: 214 outputs: A list of the same length as decoder_inputs of 2D Tensors with 215 shape [batch_size x output_size] containing the generated outputs. 216 state: The state of each decoder cell in each time-step. This is a list 217 with length len(decoder_inputs) -- one item for each time-step. 218 It is a 2D Tensor of shape [batch_size x cell.state_size]. 219 """ 220 with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): 221 scope = scope or "tied_rnn_seq2seq" 222 _, enc_state = rnn.static_rnn( 223 cell, encoder_inputs, dtype=dtype, scope=scope) 224 variable_scope.get_variable_scope().reuse_variables() 225 return rnn_decoder( 226 decoder_inputs, 227 enc_state, 228 cell, 229 loop_function=loop_function, 230 scope=scope) 231 232 233def embedding_rnn_decoder(decoder_inputs, 234 initial_state, 235 cell, 236 num_symbols, 237 embedding_size, 238 output_projection=None, 239 feed_previous=False, 240 update_embedding_for_previous=True, 241 scope=None): 242 """RNN decoder with embedding and a pure-decoding option. 243 244 Args: 245 decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 246 initial_state: 2D Tensor [batch_size x cell.state_size]. 247 cell: tf.nn.rnn_cell.RNNCell defining the cell function. 248 num_symbols: Integer, how many symbols come into the embedding. 249 embedding_size: Integer, the length of the embedding vector for each symbol. 250 output_projection: None or a pair (W, B) of output projection weights and 251 biases; W has shape [output_size x num_symbols] and B has 252 shape [num_symbols]; if provided and feed_previous=True, each fed 253 previous output will first be multiplied by W and added B. 254 feed_previous: Boolean; if True, only the first of decoder_inputs will be 255 used (the "GO" symbol), and all other decoder inputs will be generated by: 256 next = embedding_lookup(embedding, argmax(previous_output)), 257 In effect, this implements a greedy decoder. It can also be used 258 during training to emulate http://arxiv.org/abs/1506.03099. 259 If False, decoder_inputs are used as given (the standard decoder case). 260 update_embedding_for_previous: Boolean; if False and feed_previous=True, 261 only the embedding for the first symbol of decoder_inputs (the "GO" 262 symbol) will be updated by back propagation. Embeddings for the symbols 263 generated from the decoder itself remain unchanged. This parameter has 264 no effect if feed_previous=False. 265 scope: VariableScope for the created subgraph; defaults to 266 "embedding_rnn_decoder". 267 268 Returns: 269 A tuple of the form (outputs, state), where: 270 outputs: A list of the same length as decoder_inputs of 2D Tensors. The 271 output is of shape [batch_size x cell.output_size] when 272 output_projection is not None (and represents the dense representation 273 of predicted tokens). It is of shape [batch_size x num_decoder_symbols] 274 when output_projection is None. 275 state: The state of each decoder cell in each time-step. This is a list 276 with length len(decoder_inputs) -- one item for each time-step. 277 It is a 2D Tensor of shape [batch_size x cell.state_size]. 278 279 Raises: 280 ValueError: When output_projection has the wrong shape. 281 """ 282 with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope: 283 if output_projection is not None: 284 dtype = scope.dtype 285 proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 286 proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 287 proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 288 proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 289 290 embedding = variable_scope.get_variable("embedding", 291 [num_symbols, embedding_size]) 292 loop_function = _extract_argmax_and_embed( 293 embedding, output_projection, 294 update_embedding_for_previous) if feed_previous else None 295 emb_inp = (embedding_ops.embedding_lookup(embedding, i) 296 for i in decoder_inputs) 297 return rnn_decoder( 298 emb_inp, initial_state, cell, loop_function=loop_function) 299 300 301def embedding_rnn_seq2seq(encoder_inputs, 302 decoder_inputs, 303 cell, 304 num_encoder_symbols, 305 num_decoder_symbols, 306 embedding_size, 307 output_projection=None, 308 feed_previous=False, 309 dtype=None, 310 scope=None): 311 """Embedding RNN sequence-to-sequence model. 312 313 This model first embeds encoder_inputs by a newly created embedding (of shape 314 [num_encoder_symbols x input_size]). Then it runs an RNN to encode 315 embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs 316 by another newly created embedding (of shape [num_decoder_symbols x 317 input_size]). Then it runs RNN decoder, initialized with the last 318 encoder state, on embedded decoder_inputs. 319 320 Args: 321 encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 322 decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 323 cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. 324 num_encoder_symbols: Integer; number of symbols on the encoder side. 325 num_decoder_symbols: Integer; number of symbols on the decoder side. 326 embedding_size: Integer, the length of the embedding vector for each symbol. 327 output_projection: None or a pair (W, B) of output projection weights and 328 biases; W has shape [output_size x num_decoder_symbols] and B has 329 shape [num_decoder_symbols]; if provided and feed_previous=True, each 330 fed previous output will first be multiplied by W and added B. 331 feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 332 of decoder_inputs will be used (the "GO" symbol), and all other decoder 333 inputs will be taken from previous outputs (as in embedding_rnn_decoder). 334 If False, decoder_inputs are used as given (the standard decoder case). 335 dtype: The dtype of the initial state for both the encoder and encoder 336 rnn cells (default: tf.float32). 337 scope: VariableScope for the created subgraph; defaults to 338 "embedding_rnn_seq2seq" 339 340 Returns: 341 A tuple of the form (outputs, state), where: 342 outputs: A list of the same length as decoder_inputs of 2D Tensors. The 343 output is of shape [batch_size x cell.output_size] when 344 output_projection is not None (and represents the dense representation 345 of predicted tokens). It is of shape [batch_size x num_decoder_symbols] 346 when output_projection is None. 347 state: The state of each decoder cell in each time-step. This is a list 348 with length len(decoder_inputs) -- one item for each time-step. 349 It is a 2D Tensor of shape [batch_size x cell.state_size]. 350 """ 351 with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope: 352 if dtype is not None: 353 scope.set_dtype(dtype) 354 else: 355 dtype = scope.dtype 356 357 # Encoder. 358 encoder_cell = copy.deepcopy(cell) 359 encoder_cell = core_rnn_cell.EmbeddingWrapper( 360 encoder_cell, 361 embedding_classes=num_encoder_symbols, 362 embedding_size=embedding_size) 363 _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype) 364 365 # Decoder. 366 if output_projection is None: 367 cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 368 369 if isinstance(feed_previous, bool): 370 return embedding_rnn_decoder( 371 decoder_inputs, 372 encoder_state, 373 cell, 374 num_decoder_symbols, 375 embedding_size, 376 output_projection=output_projection, 377 feed_previous=feed_previous) 378 379 # If feed_previous is a Tensor, we construct 2 graphs and use cond. 380 def decoder(feed_previous_bool): 381 reuse = None if feed_previous_bool else True 382 with variable_scope.variable_scope( 383 variable_scope.get_variable_scope(), reuse=reuse): 384 outputs, state = embedding_rnn_decoder( 385 decoder_inputs, 386 encoder_state, 387 cell, 388 num_decoder_symbols, 389 embedding_size, 390 output_projection=output_projection, 391 feed_previous=feed_previous_bool, 392 update_embedding_for_previous=False) 393 state_list = [state] 394 if nest.is_sequence(state): 395 state_list = nest.flatten(state) 396 return outputs + state_list 397 398 outputs_and_state = control_flow_ops.cond(feed_previous, 399 lambda: decoder(True), 400 lambda: decoder(False)) 401 outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 402 state_list = outputs_and_state[outputs_len:] 403 state = state_list[0] 404 if nest.is_sequence(encoder_state): 405 state = nest.pack_sequence_as( 406 structure=encoder_state, flat_sequence=state_list) 407 return outputs_and_state[:outputs_len], state 408 409 410def embedding_tied_rnn_seq2seq(encoder_inputs, 411 decoder_inputs, 412 cell, 413 num_symbols, 414 embedding_size, 415 num_decoder_symbols=None, 416 output_projection=None, 417 feed_previous=False, 418 dtype=None, 419 scope=None): 420 """Embedding RNN sequence-to-sequence model with tied (shared) parameters. 421 422 This model first embeds encoder_inputs by a newly created embedding (of shape 423 [num_symbols x input_size]). Then it runs an RNN to encode embedded 424 encoder_inputs into a state vector. Next, it embeds decoder_inputs using 425 the same embedding. Then it runs RNN decoder, initialized with the last 426 encoder state, on embedded decoder_inputs. The decoder output is over symbols 427 from 0 to num_decoder_symbols - 1 if num_decoder_symbols is none; otherwise it 428 is over 0 to num_symbols - 1. 429 430 Args: 431 encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 432 decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 433 cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. 434 num_symbols: Integer; number of symbols for both encoder and decoder. 435 embedding_size: Integer, the length of the embedding vector for each symbol. 436 num_decoder_symbols: Integer; number of output symbols for decoder. If 437 provided, the decoder output is over symbols 0 to num_decoder_symbols - 1. 438 Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that 439 this assumes that the vocabulary is set up such that the first 440 num_decoder_symbols of num_symbols are part of decoding. 441 output_projection: None or a pair (W, B) of output projection weights and 442 biases; W has shape [output_size x num_symbols] and B has 443 shape [num_symbols]; if provided and feed_previous=True, each 444 fed previous output will first be multiplied by W and added B. 445 feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 446 of decoder_inputs will be used (the "GO" symbol), and all other decoder 447 inputs will be taken from previous outputs (as in embedding_rnn_decoder). 448 If False, decoder_inputs are used as given (the standard decoder case). 449 dtype: The dtype to use for the initial RNN states (default: tf.float32). 450 scope: VariableScope for the created subgraph; defaults to 451 "embedding_tied_rnn_seq2seq". 452 453 Returns: 454 A tuple of the form (outputs, state), where: 455 outputs: A list of the same length as decoder_inputs of 2D Tensors with 456 shape [batch_size x output_symbols] containing the generated 457 outputs where output_symbols = num_decoder_symbols if 458 num_decoder_symbols is not None otherwise output_symbols = num_symbols. 459 state: The state of each decoder cell at the final time-step. 460 It is a 2D Tensor of shape [batch_size x cell.state_size]. 461 462 Raises: 463 ValueError: When output_projection has the wrong shape. 464 """ 465 with variable_scope.variable_scope( 466 scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope: 467 dtype = scope.dtype 468 469 if output_projection is not None: 470 proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 471 proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 472 proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 473 proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 474 475 embedding = variable_scope.get_variable( 476 "embedding", [num_symbols, embedding_size], dtype=dtype) 477 478 emb_encoder_inputs = [ 479 embedding_ops.embedding_lookup(embedding, x) for x in encoder_inputs 480 ] 481 emb_decoder_inputs = [ 482 embedding_ops.embedding_lookup(embedding, x) for x in decoder_inputs 483 ] 484 485 output_symbols = num_symbols 486 if num_decoder_symbols is not None: 487 output_symbols = num_decoder_symbols 488 if output_projection is None: 489 cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols) 490 491 if isinstance(feed_previous, bool): 492 loop_function = _extract_argmax_and_embed(embedding, output_projection, 493 True) if feed_previous else None 494 return tied_rnn_seq2seq( 495 emb_encoder_inputs, 496 emb_decoder_inputs, 497 cell, 498 loop_function=loop_function, 499 dtype=dtype) 500 501 # If feed_previous is a Tensor, we construct 2 graphs and use cond. 502 def decoder(feed_previous_bool): 503 loop_function = _extract_argmax_and_embed( 504 embedding, output_projection, False) if feed_previous_bool else None 505 reuse = None if feed_previous_bool else True 506 with variable_scope.variable_scope( 507 variable_scope.get_variable_scope(), reuse=reuse): 508 outputs, state = tied_rnn_seq2seq( 509 emb_encoder_inputs, 510 emb_decoder_inputs, 511 cell, 512 loop_function=loop_function, 513 dtype=dtype) 514 state_list = [state] 515 if nest.is_sequence(state): 516 state_list = nest.flatten(state) 517 return outputs + state_list 518 519 outputs_and_state = control_flow_ops.cond(feed_previous, 520 lambda: decoder(True), 521 lambda: decoder(False)) 522 outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 523 state_list = outputs_and_state[outputs_len:] 524 state = state_list[0] 525 # Calculate zero-state to know it's structure. 526 static_batch_size = encoder_inputs[0].get_shape()[0] 527 for inp in encoder_inputs[1:]: 528 static_batch_size.merge_with(inp.get_shape()[0]) 529 batch_size = static_batch_size.value 530 if batch_size is None: 531 batch_size = array_ops.shape(encoder_inputs[0])[0] 532 zero_state = cell.zero_state(batch_size, dtype) 533 if nest.is_sequence(zero_state): 534 state = nest.pack_sequence_as( 535 structure=zero_state, flat_sequence=state_list) 536 return outputs_and_state[:outputs_len], state 537 538 539def attention_decoder(decoder_inputs, 540 initial_state, 541 attention_states, 542 cell, 543 output_size=None, 544 num_heads=1, 545 loop_function=None, 546 dtype=None, 547 scope=None, 548 initial_state_attention=False): 549 """RNN decoder with attention for the sequence-to-sequence model. 550 551 In this context "attention" means that, during decoding, the RNN can look up 552 information in the additional tensor attention_states, and it does this by 553 focusing on a few entries from the tensor. This model has proven to yield 554 especially good results in a number of sequence-to-sequence tasks. This 555 implementation is based on http://arxiv.org/abs/1412.7449 (see below for 556 details). It is recommended for complex sequence-to-sequence tasks. 557 558 Args: 559 decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 560 initial_state: 2D Tensor [batch_size x cell.state_size]. 561 attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 562 cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. 563 output_size: Size of the output vectors; if None, we use cell.output_size. 564 num_heads: Number of attention heads that read from attention_states. 565 loop_function: If not None, this function will be applied to i-th output 566 in order to generate i+1-th input, and decoder_inputs will be ignored, 567 except for the first element ("GO" symbol). This can be used for decoding, 568 but also for training to emulate http://arxiv.org/abs/1506.03099. 569 Signature -- loop_function(prev, i) = next 570 * prev is a 2D Tensor of shape [batch_size x output_size], 571 * i is an integer, the step number (when advanced control is needed), 572 * next is a 2D Tensor of shape [batch_size x input_size]. 573 dtype: The dtype to use for the RNN initial state (default: tf.float32). 574 scope: VariableScope for the created subgraph; default: "attention_decoder". 575 initial_state_attention: If False (default), initial attentions are zero. 576 If True, initialize the attentions from the initial state and attention 577 states -- useful when we wish to resume decoding from a previously 578 stored decoder state and attention states. 579 580 Returns: 581 A tuple of the form (outputs, state), where: 582 outputs: A list of the same length as decoder_inputs of 2D Tensors of 583 shape [batch_size x output_size]. These represent the generated outputs. 584 Output i is computed from input i (which is either the i-th element 585 of decoder_inputs or loop_function(output {i-1}, i)) as follows. 586 First, we run the cell on a combination of the input and previous 587 attention masks: 588 cell_output, new_state = cell(linear(input, prev_attn), prev_state). 589 Then, we calculate new attention masks: 590 new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) 591 and then we calculate the output: 592 output = linear(cell_output, new_attn). 593 state: The state of each decoder cell the final time-step. 594 It is a 2D Tensor of shape [batch_size x cell.state_size]. 595 596 Raises: 597 ValueError: when num_heads is not positive, there are no inputs, shapes 598 of attention_states are not set, or input size cannot be inferred 599 from the input. 600 """ 601 if not decoder_inputs: 602 raise ValueError("Must provide at least 1 input to attention decoder.") 603 if num_heads < 1: 604 raise ValueError("With less than 1 heads, use a non-attention decoder.") 605 if attention_states.get_shape()[2].value is None: 606 raise ValueError("Shape[2] of attention_states must be known: %s" % 607 attention_states.get_shape()) 608 if output_size is None: 609 output_size = cell.output_size 610 611 with variable_scope.variable_scope( 612 scope or "attention_decoder", dtype=dtype) as scope: 613 dtype = scope.dtype 614 615 batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 616 attn_length = attention_states.get_shape()[1].value 617 if attn_length is None: 618 attn_length = array_ops.shape(attention_states)[1] 619 attn_size = attention_states.get_shape()[2].value 620 621 # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 622 hidden = array_ops.reshape(attention_states, 623 [-1, attn_length, 1, attn_size]) 624 hidden_features = [] 625 v = [] 626 attention_vec_size = attn_size # Size of query vectors for attention. 627 for a in xrange(num_heads): 628 k = variable_scope.get_variable("AttnW_%d" % a, 629 [1, 1, attn_size, attention_vec_size]) 630 hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) 631 v.append( 632 variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size])) 633 634 state = initial_state 635 636 def attention(query): 637 """Put attention masks on hidden using hidden_features and query.""" 638 ds = [] # Results of attention reads will be stored here. 639 if nest.is_sequence(query): # If the query is a tuple, flatten it. 640 query_list = nest.flatten(query) 641 for q in query_list: # Check that ndims == 2 if specified. 642 ndims = q.get_shape().ndims 643 if ndims: 644 assert ndims == 2 645 query = array_ops.concat(query_list, 1) 646 for a in xrange(num_heads): 647 with variable_scope.variable_scope("Attention_%d" % a): 648 y = Linear(query, attention_vec_size, True)(query) 649 y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 650 # Attention mask is a softmax of v^T * tanh(...). 651 s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), 652 [2, 3]) 653 a = nn_ops.softmax(s) 654 # Now calculate the attention-weighted vector d. 655 d = math_ops.reduce_sum( 656 array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2]) 657 ds.append(array_ops.reshape(d, [-1, attn_size])) 658 return ds 659 660 outputs = [] 661 prev = None 662 batch_attn_size = array_ops.stack([batch_size, attn_size]) 663 attns = [ 664 array_ops.zeros( 665 batch_attn_size, dtype=dtype) for _ in xrange(num_heads) 666 ] 667 for a in attns: # Ensure the second shape of attention vectors is set. 668 a.set_shape([None, attn_size]) 669 if initial_state_attention: 670 attns = attention(initial_state) 671 for i, inp in enumerate(decoder_inputs): 672 if i > 0: 673 variable_scope.get_variable_scope().reuse_variables() 674 # If loop_function is set, we use it instead of decoder_inputs. 675 if loop_function is not None and prev is not None: 676 with variable_scope.variable_scope("loop_function", reuse=True): 677 inp = loop_function(prev, i) 678 # Merge input and previous attentions into one vector of the right size. 679 input_size = inp.get_shape().with_rank(2)[1] 680 if input_size.value is None: 681 raise ValueError("Could not infer input size from input: %s" % inp.name) 682 683 inputs = [inp] + attns 684 x = Linear(inputs, input_size, True)(inputs) 685 # Run the RNN. 686 cell_output, state = cell(x, state) 687 # Run the attention mechanism. 688 if i == 0 and initial_state_attention: 689 with variable_scope.variable_scope( 690 variable_scope.get_variable_scope(), reuse=True): 691 attns = attention(state) 692 else: 693 attns = attention(state) 694 695 with variable_scope.variable_scope("AttnOutputProjection"): 696 inputs = [cell_output] + attns 697 output = Linear(inputs, output_size, True)(inputs) 698 if loop_function is not None: 699 prev = output 700 outputs.append(output) 701 702 return outputs, state 703 704 705def embedding_attention_decoder(decoder_inputs, 706 initial_state, 707 attention_states, 708 cell, 709 num_symbols, 710 embedding_size, 711 num_heads=1, 712 output_size=None, 713 output_projection=None, 714 feed_previous=False, 715 update_embedding_for_previous=True, 716 dtype=None, 717 scope=None, 718 initial_state_attention=False): 719 """RNN decoder with embedding and attention and a pure-decoding option. 720 721 Args: 722 decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 723 initial_state: 2D Tensor [batch_size x cell.state_size]. 724 attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 725 cell: tf.nn.rnn_cell.RNNCell defining the cell function. 726 num_symbols: Integer, how many symbols come into the embedding. 727 embedding_size: Integer, the length of the embedding vector for each symbol. 728 num_heads: Number of attention heads that read from attention_states. 729 output_size: Size of the output vectors; if None, use output_size. 730 output_projection: None or a pair (W, B) of output projection weights and 731 biases; W has shape [output_size x num_symbols] and B has shape 732 [num_symbols]; if provided and feed_previous=True, each fed previous 733 output will first be multiplied by W and added B. 734 feed_previous: Boolean; if True, only the first of decoder_inputs will be 735 used (the "GO" symbol), and all other decoder inputs will be generated by: 736 next = embedding_lookup(embedding, argmax(previous_output)), 737 In effect, this implements a greedy decoder. It can also be used 738 during training to emulate http://arxiv.org/abs/1506.03099. 739 If False, decoder_inputs are used as given (the standard decoder case). 740 update_embedding_for_previous: Boolean; if False and feed_previous=True, 741 only the embedding for the first symbol of decoder_inputs (the "GO" 742 symbol) will be updated by back propagation. Embeddings for the symbols 743 generated from the decoder itself remain unchanged. This parameter has 744 no effect if feed_previous=False. 745 dtype: The dtype to use for the RNN initial states (default: tf.float32). 746 scope: VariableScope for the created subgraph; defaults to 747 "embedding_attention_decoder". 748 initial_state_attention: If False (default), initial attentions are zero. 749 If True, initialize the attentions from the initial state and attention 750 states -- useful when we wish to resume decoding from a previously 751 stored decoder state and attention states. 752 753 Returns: 754 A tuple of the form (outputs, state), where: 755 outputs: A list of the same length as decoder_inputs of 2D Tensors with 756 shape [batch_size x output_size] containing the generated outputs. 757 state: The state of each decoder cell at the final time-step. 758 It is a 2D Tensor of shape [batch_size x cell.state_size]. 759 760 Raises: 761 ValueError: When output_projection has the wrong shape. 762 """ 763 if output_size is None: 764 output_size = cell.output_size 765 if output_projection is not None: 766 proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 767 proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 768 769 with variable_scope.variable_scope( 770 scope or "embedding_attention_decoder", dtype=dtype) as scope: 771 772 embedding = variable_scope.get_variable("embedding", 773 [num_symbols, embedding_size]) 774 loop_function = _extract_argmax_and_embed( 775 embedding, output_projection, 776 update_embedding_for_previous) if feed_previous else None 777 emb_inp = [ 778 embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs 779 ] 780 return attention_decoder( 781 emb_inp, 782 initial_state, 783 attention_states, 784 cell, 785 output_size=output_size, 786 num_heads=num_heads, 787 loop_function=loop_function, 788 initial_state_attention=initial_state_attention) 789 790 791def embedding_attention_seq2seq(encoder_inputs, 792 decoder_inputs, 793 cell, 794 num_encoder_symbols, 795 num_decoder_symbols, 796 embedding_size, 797 num_heads=1, 798 output_projection=None, 799 feed_previous=False, 800 dtype=None, 801 scope=None, 802 initial_state_attention=False): 803 """Embedding sequence-to-sequence model with attention. 804 805 This model first embeds encoder_inputs by a newly created embedding (of shape 806 [num_encoder_symbols x input_size]). Then it runs an RNN to encode 807 embedded encoder_inputs into a state vector. It keeps the outputs of this 808 RNN at every step to use for attention later. Next, it embeds decoder_inputs 809 by another newly created embedding (of shape [num_decoder_symbols x 810 input_size]). Then it runs attention decoder, initialized with the last 811 encoder state, on embedded decoder_inputs and attending to encoder outputs. 812 813 Warning: when output_projection is None, the size of the attention vectors 814 and variables will be made proportional to num_decoder_symbols, can be large. 815 816 Args: 817 encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 818 decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 819 cell: tf.nn.rnn_cell.RNNCell defining the cell function and size. 820 num_encoder_symbols: Integer; number of symbols on the encoder side. 821 num_decoder_symbols: Integer; number of symbols on the decoder side. 822 embedding_size: Integer, the length of the embedding vector for each symbol. 823 num_heads: Number of attention heads that read from attention_states. 824 output_projection: None or a pair (W, B) of output projection weights and 825 biases; W has shape [output_size x num_decoder_symbols] and B has 826 shape [num_decoder_symbols]; if provided and feed_previous=True, each 827 fed previous output will first be multiplied by W and added B. 828 feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 829 of decoder_inputs will be used (the "GO" symbol), and all other decoder 830 inputs will be taken from previous outputs (as in embedding_rnn_decoder). 831 If False, decoder_inputs are used as given (the standard decoder case). 832 dtype: The dtype of the initial RNN state (default: tf.float32). 833 scope: VariableScope for the created subgraph; defaults to 834 "embedding_attention_seq2seq". 835 initial_state_attention: If False (default), initial attentions are zero. 836 If True, initialize the attentions from the initial state and attention 837 states. 838 839 Returns: 840 A tuple of the form (outputs, state), where: 841 outputs: A list of the same length as decoder_inputs of 2D Tensors with 842 shape [batch_size x num_decoder_symbols] containing the generated 843 outputs. 844 state: The state of each decoder cell at the final time-step. 845 It is a 2D Tensor of shape [batch_size x cell.state_size]. 846 """ 847 with variable_scope.variable_scope( 848 scope or "embedding_attention_seq2seq", dtype=dtype) as scope: 849 dtype = scope.dtype 850 # Encoder. 851 encoder_cell = copy.deepcopy(cell) 852 encoder_cell = core_rnn_cell.EmbeddingWrapper( 853 encoder_cell, 854 embedding_classes=num_encoder_symbols, 855 embedding_size=embedding_size) 856 encoder_outputs, encoder_state = rnn.static_rnn( 857 encoder_cell, encoder_inputs, dtype=dtype) 858 859 # First calculate a concatenation of encoder outputs to put attention on. 860 top_states = [ 861 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs 862 ] 863 attention_states = array_ops.concat(top_states, 1) 864 865 # Decoder. 866 output_size = None 867 if output_projection is None: 868 cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 869 output_size = num_decoder_symbols 870 871 if isinstance(feed_previous, bool): 872 return embedding_attention_decoder( 873 decoder_inputs, 874 encoder_state, 875 attention_states, 876 cell, 877 num_decoder_symbols, 878 embedding_size, 879 num_heads=num_heads, 880 output_size=output_size, 881 output_projection=output_projection, 882 feed_previous=feed_previous, 883 initial_state_attention=initial_state_attention) 884 885 # If feed_previous is a Tensor, we construct 2 graphs and use cond. 886 def decoder(feed_previous_bool): 887 reuse = None if feed_previous_bool else True 888 with variable_scope.variable_scope( 889 variable_scope.get_variable_scope(), reuse=reuse): 890 outputs, state = embedding_attention_decoder( 891 decoder_inputs, 892 encoder_state, 893 attention_states, 894 cell, 895 num_decoder_symbols, 896 embedding_size, 897 num_heads=num_heads, 898 output_size=output_size, 899 output_projection=output_projection, 900 feed_previous=feed_previous_bool, 901 update_embedding_for_previous=False, 902 initial_state_attention=initial_state_attention) 903 state_list = [state] 904 if nest.is_sequence(state): 905 state_list = nest.flatten(state) 906 return outputs + state_list 907 908 outputs_and_state = control_flow_ops.cond(feed_previous, 909 lambda: decoder(True), 910 lambda: decoder(False)) 911 outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 912 state_list = outputs_and_state[outputs_len:] 913 state = state_list[0] 914 if nest.is_sequence(encoder_state): 915 state = nest.pack_sequence_as( 916 structure=encoder_state, flat_sequence=state_list) 917 return outputs_and_state[:outputs_len], state 918 919 920def one2many_rnn_seq2seq(encoder_inputs, 921 decoder_inputs_dict, 922 enc_cell, 923 dec_cells_dict, 924 num_encoder_symbols, 925 num_decoder_symbols_dict, 926 embedding_size, 927 feed_previous=False, 928 dtype=None, 929 scope=None): 930 """One-to-many RNN sequence-to-sequence model (multi-task). 931 932 This is a multi-task sequence-to-sequence model with one encoder and multiple 933 decoders. Reference to multi-task sequence-to-sequence learning can be found 934 here: http://arxiv.org/abs/1511.06114 935 936 Args: 937 encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 938 decoder_inputs_dict: A dictionary mapping decoder name (string) to 939 the corresponding decoder_inputs; each decoder_inputs is a list of 1D 940 Tensors of shape [batch_size]; num_decoders is defined as 941 len(decoder_inputs_dict). 942 enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and 943 size. 944 dec_cells_dict: A dictionary mapping encoder name (string) to an 945 instance of tf.nn.rnn_cell.RNNCell. 946 num_encoder_symbols: Integer; number of symbols on the encoder side. 947 num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an 948 integer specifying number of symbols for the corresponding decoder; 949 len(num_decoder_symbols_dict) must be equal to num_decoders. 950 embedding_size: Integer, the length of the embedding vector for each symbol. 951 feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of 952 decoder_inputs will be used (the "GO" symbol), and all other decoder 953 inputs will be taken from previous outputs (as in embedding_rnn_decoder). 954 If False, decoder_inputs are used as given (the standard decoder case). 955 dtype: The dtype of the initial state for both the encoder and encoder 956 rnn cells (default: tf.float32). 957 scope: VariableScope for the created subgraph; defaults to 958 "one2many_rnn_seq2seq" 959 960 Returns: 961 A tuple of the form (outputs_dict, state_dict), where: 962 outputs_dict: A mapping from decoder name (string) to a list of the same 963 length as decoder_inputs_dict[name]; each element in the list is a 2D 964 Tensors with shape [batch_size x num_decoder_symbol_list[name]] 965 containing the generated outputs. 966 state_dict: A mapping from decoder name (string) to the final state of the 967 corresponding decoder RNN; it is a 2D Tensor of shape 968 [batch_size x cell.state_size]. 969 970 Raises: 971 TypeError: if enc_cell or any of the dec_cells are not instances of RNNCell. 972 ValueError: if len(dec_cells) != len(decoder_inputs_dict). 973 """ 974 outputs_dict = {} 975 state_dict = {} 976 977 if not isinstance(enc_cell, rnn_cell_impl.RNNCell): 978 raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell)) 979 if set(dec_cells_dict) != set(decoder_inputs_dict): 980 raise ValueError("keys of dec_cells_dict != keys of decodre_inputs_dict") 981 for dec_cell in dec_cells_dict.values(): 982 if not isinstance(dec_cell, rnn_cell_impl.RNNCell): 983 raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell)) 984 985 with variable_scope.variable_scope( 986 scope or "one2many_rnn_seq2seq", dtype=dtype) as scope: 987 dtype = scope.dtype 988 989 # Encoder. 990 enc_cell = core_rnn_cell.EmbeddingWrapper( 991 enc_cell, 992 embedding_classes=num_encoder_symbols, 993 embedding_size=embedding_size) 994 _, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype) 995 996 # Decoder. 997 for name, decoder_inputs in decoder_inputs_dict.items(): 998 num_decoder_symbols = num_decoder_symbols_dict[name] 999 dec_cell = dec_cells_dict[name] 1000 1001 with variable_scope.variable_scope("one2many_decoder_" + str( 1002 name)) as scope: 1003 dec_cell = core_rnn_cell.OutputProjectionWrapper( 1004 dec_cell, num_decoder_symbols) 1005 if isinstance(feed_previous, bool): 1006 outputs, state = embedding_rnn_decoder( 1007 decoder_inputs, 1008 encoder_state, 1009 dec_cell, 1010 num_decoder_symbols, 1011 embedding_size, 1012 feed_previous=feed_previous) 1013 else: 1014 # If feed_previous is a Tensor, we construct 2 graphs and use cond. 1015 def filled_embedding_rnn_decoder(feed_previous): 1016 """The current decoder with a fixed feed_previous parameter.""" 1017 # pylint: disable=cell-var-from-loop 1018 reuse = None if feed_previous else True 1019 vs = variable_scope.get_variable_scope() 1020 with variable_scope.variable_scope(vs, reuse=reuse): 1021 outputs, state = embedding_rnn_decoder( 1022 decoder_inputs, 1023 encoder_state, 1024 dec_cell, 1025 num_decoder_symbols, 1026 embedding_size, 1027 feed_previous=feed_previous) 1028 # pylint: enable=cell-var-from-loop 1029 state_list = [state] 1030 if nest.is_sequence(state): 1031 state_list = nest.flatten(state) 1032 return outputs + state_list 1033 1034 outputs_and_state = control_flow_ops.cond( 1035 feed_previous, lambda: filled_embedding_rnn_decoder(True), 1036 lambda: filled_embedding_rnn_decoder(False)) 1037 # Outputs length is the same as for decoder inputs. 1038 outputs_len = len(decoder_inputs) 1039 outputs = outputs_and_state[:outputs_len] 1040 state_list = outputs_and_state[outputs_len:] 1041 state = state_list[0] 1042 if nest.is_sequence(encoder_state): 1043 state = nest.pack_sequence_as( 1044 structure=encoder_state, flat_sequence=state_list) 1045 outputs_dict[name] = outputs 1046 state_dict[name] = state 1047 1048 return outputs_dict, state_dict 1049 1050 1051def sequence_loss_by_example(logits, 1052 targets, 1053 weights, 1054 average_across_timesteps=True, 1055 softmax_loss_function=None, 1056 name=None): 1057 """Weighted cross-entropy loss for a sequence of logits (per example). 1058 1059 Args: 1060 logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 1061 targets: List of 1D batch-sized int32 Tensors of the same length as logits. 1062 weights: List of 1D batch-sized float-Tensors of the same length as logits. 1063 average_across_timesteps: If set, divide the returned cost by the total 1064 label weight. 1065 softmax_loss_function: Function (labels, logits) -> loss-batch 1066 to be used instead of the standard softmax (the default if this is None). 1067 **Note that to avoid confusion, it is required for the function to accept 1068 named arguments.** 1069 name: Optional name for this operation, default: "sequence_loss_by_example". 1070 1071 Returns: 1072 1D batch-sized float Tensor: The log-perplexity for each sequence. 1073 1074 Raises: 1075 ValueError: If len(logits) is different from len(targets) or len(weights). 1076 """ 1077 if len(targets) != len(logits) or len(weights) != len(logits): 1078 raise ValueError("Lengths of logits, weights, and targets must be the same " 1079 "%d, %d, %d." % (len(logits), len(weights), len(targets))) 1080 with ops.name_scope(name, "sequence_loss_by_example", 1081 logits + targets + weights): 1082 log_perp_list = [] 1083 for logit, target, weight in zip(logits, targets, weights): 1084 if softmax_loss_function is None: 1085 # TODO(irving,ebrevdo): This reshape is needed because 1086 # sequence_loss_by_example is called with scalars sometimes, which 1087 # violates our general scalar strictness policy. 1088 target = array_ops.reshape(target, [-1]) 1089 crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 1090 labels=target, logits=logit) 1091 else: 1092 crossent = softmax_loss_function(labels=target, logits=logit) 1093 log_perp_list.append(crossent * weight) 1094 log_perps = math_ops.add_n(log_perp_list) 1095 if average_across_timesteps: 1096 total_size = math_ops.add_n(weights) 1097 total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 1098 log_perps /= total_size 1099 return log_perps 1100 1101 1102def sequence_loss(logits, 1103 targets, 1104 weights, 1105 average_across_timesteps=True, 1106 average_across_batch=True, 1107 softmax_loss_function=None, 1108 name=None): 1109 """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. 1110 1111 Args: 1112 logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 1113 targets: List of 1D batch-sized int32 Tensors of the same length as logits. 1114 weights: List of 1D batch-sized float-Tensors of the same length as logits. 1115 average_across_timesteps: If set, divide the returned cost by the total 1116 label weight. 1117 average_across_batch: If set, divide the returned cost by the batch size. 1118 softmax_loss_function: Function (labels, logits) -> loss-batch 1119 to be used instead of the standard softmax (the default if this is None). 1120 **Note that to avoid confusion, it is required for the function to accept 1121 named arguments.** 1122 name: Optional name for this operation, defaults to "sequence_loss". 1123 1124 Returns: 1125 A scalar float Tensor: The average log-perplexity per symbol (weighted). 1126 1127 Raises: 1128 ValueError: If len(logits) is different from len(targets) or len(weights). 1129 """ 1130 with ops.name_scope(name, "sequence_loss", logits + targets + weights): 1131 cost = math_ops.reduce_sum( 1132 sequence_loss_by_example( 1133 logits, 1134 targets, 1135 weights, 1136 average_across_timesteps=average_across_timesteps, 1137 softmax_loss_function=softmax_loss_function)) 1138 if average_across_batch: 1139 batch_size = array_ops.shape(targets[0])[0] 1140 return cost / math_ops.cast(batch_size, cost.dtype) 1141 else: 1142 return cost 1143 1144 1145def model_with_buckets(encoder_inputs, 1146 decoder_inputs, 1147 targets, 1148 weights, 1149 buckets, 1150 seq2seq, 1151 softmax_loss_function=None, 1152 per_example_loss=False, 1153 name=None): 1154 """Create a sequence-to-sequence model with support for bucketing. 1155 1156 The seq2seq argument is a function that defines a sequence-to-sequence model, 1157 e.g., seq2seq = lambda x, y: basic_rnn_seq2seq( 1158 x, y, rnn_cell.GRUCell(24)) 1159 1160 Args: 1161 encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input. 1162 decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input. 1163 targets: A list of 1D batch-sized int32 Tensors (desired output sequence). 1164 weights: List of 1D batch-sized float-Tensors to weight the targets. 1165 buckets: A list of pairs of (input size, output size) for each bucket. 1166 seq2seq: A sequence-to-sequence model function; it takes 2 input that 1167 agree with encoder_inputs and decoder_inputs, and returns a pair 1168 consisting of outputs and states (as, e.g., basic_rnn_seq2seq). 1169 softmax_loss_function: Function (labels, logits) -> loss-batch 1170 to be used instead of the standard softmax (the default if this is None). 1171 **Note that to avoid confusion, it is required for the function to accept 1172 named arguments.** 1173 per_example_loss: Boolean. If set, the returned loss will be a batch-sized 1174 tensor of losses for each sequence in the batch. If unset, it will be 1175 a scalar with the averaged loss from all examples. 1176 name: Optional name for this operation, defaults to "model_with_buckets". 1177 1178 Returns: 1179 A tuple of the form (outputs, losses), where: 1180 outputs: The outputs for each bucket. Its j'th element consists of a list 1181 of 2D Tensors. The shape of output tensors can be either 1182 [batch_size x output_size] or [batch_size x num_decoder_symbols] 1183 depending on the seq2seq model used. 1184 losses: List of scalar Tensors, representing losses for each bucket, or, 1185 if per_example_loss is set, a list of 1D batch-sized float Tensors. 1186 1187 Raises: 1188 ValueError: If length of encoder_inputs, targets, or weights is smaller 1189 than the largest (last) bucket. 1190 """ 1191 if len(encoder_inputs) < buckets[-1][0]: 1192 raise ValueError("Length of encoder_inputs (%d) must be at least that of la" 1193 "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0])) 1194 if len(targets) < buckets[-1][1]: 1195 raise ValueError("Length of targets (%d) must be at least that of last " 1196 "bucket (%d)." % (len(targets), buckets[-1][1])) 1197 if len(weights) < buckets[-1][1]: 1198 raise ValueError("Length of weights (%d) must be at least that of last " 1199 "bucket (%d)." % (len(weights), buckets[-1][1])) 1200 1201 all_inputs = encoder_inputs + decoder_inputs + targets + weights 1202 losses = [] 1203 outputs = [] 1204 with ops.name_scope(name, "model_with_buckets", all_inputs): 1205 for j, bucket in enumerate(buckets): 1206 with variable_scope.variable_scope( 1207 variable_scope.get_variable_scope(), reuse=True if j > 0 else None): 1208 bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]], 1209 decoder_inputs[:bucket[1]]) 1210 outputs.append(bucket_outputs) 1211 if per_example_loss: 1212 losses.append( 1213 sequence_loss_by_example( 1214 outputs[-1], 1215 targets[:bucket[1]], 1216 weights[:bucket[1]], 1217 softmax_loss_function=softmax_loss_function)) 1218 else: 1219 losses.append( 1220 sequence_loss( 1221 outputs[-1], 1222 targets[:bucket[1]], 1223 weights[:bucket[1]], 1224 softmax_loss_function=softmax_loss_function)) 1225 1226 return outputs, losses 1227