# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for creating sequence-to-sequence models in TensorFlow.

Sequence-to-sequence recurrent neural networks can learn complex functions
that map input sequences to output sequences. These models yield very good
results on a number of tasks, such as speech recognition, parsing, machine
translation, or even constructing automated replies to emails.

Before using this module, it is recommended to read the TensorFlow tutorial
on sequence-to-sequence models. It explains the basic concepts of this module
and shows an end-to-end example of how to build a translation model.
  https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html

Here is an overview of functions available in this module. They all use
a very similar interface, so after reading the above tutorial and using
one of them, others should be easy to substitute.

* Full sequence-to-sequence models.
  - basic_rnn_seq2seq: The most basic RNN-RNN model.
  - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights.
  - embedding_rnn_seq2seq: The basic model with input embedding.
  - embedding_tied_rnn_seq2seq: The tied model with input embedding.
  - embedding_attention_seq2seq: Advanced model with input embedding and
      the neural attention mechanism; recommended for complex tasks.

* Multi-task sequence-to-sequence models.
  - one2many_rnn_seq2seq: The embedding model with multiple decoders.

* Decoders (when you write your own encoder, you can use these to decode;
  e.g., if you want to write a model that generates captions for images).
  - rnn_decoder: The basic decoder based on a pure RNN.
  - attention_decoder: A decoder that uses the attention mechanism.

* Losses.
  - sequence_loss: Loss for a sequence model returning average log-perplexity.
  - sequence_loss_by_example: As above, but not averaging over all examples.

* model_with_buckets: A convenience function to create models with bucketing
    (see the tutorial above for an explanation of why and how to use it).
"""
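# Illustrative sketch (an assumption, not part of the library): all functions
# below take time-major lists of per-step tensors. Assuming this file is
# importable as `seq2seq`, a typical call looks roughly like this; the sizes
# and vocabulary numbers are made up for demonstration.
#
#   import tensorflow as tf
#
#   batch_size, num_steps = 32, 10
#   cell = tf.nn.rnn_cell.GRUCell(64)
#   encoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(num_steps)]
#   decoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(num_steps)]
#   outputs, state = seq2seq.embedding_rnn_seq2seq(
#       encoder_inputs, decoder_inputs, cell,
#       num_encoder_symbols=100, num_decoder_symbols=100, embedding_size=16)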
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

# TODO(ebrevdo): Remove once _linear is fully deprecated.
Linear = core_rnn_cell._Linear  # pylint: disable=protected-access,invalid-name


def _extract_argmax_and_embed(embedding,
                              output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and have B added to it.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """

  def loop_function(prev, _):
    if output_projection is not None:
      prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = math_ops.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = array_ops.stop_gradient(emb_prev)
    return emb_prev

  return loop_function


def rnn_decoder(decoder_inputs,
                initial_state,
                cell,
                loop_function=None,
                scope=None):
  """RNN decoder for the sequence-to-sequence model.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor with shape [batch_size x cell.state_size].
    cell: rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-st input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    scope: VariableScope for the created subgraph; defaults to "rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing generated outputs.
      state: The state of each cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
        (Note that in some cases, like basic RNN cell or GRU cell, outputs and
         states can be the same. They are different for LSTM cells though.)
  """
  with variable_scope.variable_scope(scope or "rnn_decoder"):
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      output, state = cell(inp, state)
      outputs.append(output)
      if loop_function is not None:
        prev = output
  return outputs, state
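# Illustrative sketch (an assumption, not part of the library): driving
# rnn_decoder directly with a GRU cell and already-embedded inputs. The
# shapes and sizes below are made up for demonstration.
#
#   import tensorflow as tf
#
#   batch_size, input_size, num_steps = 32, 16, 10
#   cell = tf.nn.rnn_cell.GRUCell(24)
#   decoder_inputs = [tf.zeros([batch_size, input_size])
#                     for _ in range(num_steps)]
#   initial_state = cell.zero_state(batch_size, tf.float32)
#   outputs, final_state = rnn_decoder(decoder_inputs, initial_state, cell)
#   # len(outputs) == num_steps; each output has shape [batch_size, 24].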
def basic_rnn_seq2seq(encoder_inputs,
                      decoder_inputs,
                      cell,
                      dtype=dtypes.float32,
                      scope=None):
  """Basic RNN sequence-to-sequence model.

  This model first runs an RNN to encode encoder_inputs into a state vector,
  then runs a decoder, initialized with the last encoder state, on
  decoder_inputs. Encoder and decoder use the same RNN cell type, but don't
  share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell in the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
    enc_cell = copy.deepcopy(cell)
    _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
    return rnn_decoder(decoder_inputs, enc_state, cell)


def tied_rnn_seq2seq(encoder_inputs,
                     decoder_inputs,
                     cell,
                     loop_function=None,
                     dtype=dtypes.float32,
                     scope=None):
  """RNN sequence-to-sequence model with tied encoder and decoder parameters.

  This model first runs an RNN to encode encoder_inputs into a state vector,
  and then runs a decoder, initialized with the last encoder state, on
  decoder_inputs. Encoder and decoder use the same RNN cell and share
  parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol), see rnn_decoder for details.
    dtype: The dtype of the initial state of the rnn cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope("combined_tied_rnn_seq2seq"):
    scope = scope or "tied_rnn_seq2seq"
    _, enc_state = rnn.static_rnn(
        cell, encoder_inputs, dtype=dtype, scope=scope)
    variable_scope.get_variable_scope().reuse_variables()
    return rnn_decoder(
        decoder_inputs,
        enc_state,
        cell,
        loop_function=loop_function,
        scope=scope)
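# Illustrative sketch (an assumption, not part of the library): wiring
# basic_rnn_seq2seq with dense, already-embedded inputs. Sizes are arbitrary.
#
#   import tensorflow as tf
#
#   batch_size, input_size, enc_steps, dec_steps = 32, 8, 5, 6
#   cell = tf.nn.rnn_cell.BasicLSTMCell(24)
#   encoder_inputs = [tf.zeros([batch_size, input_size])
#                     for _ in range(enc_steps)]
#   decoder_inputs = [tf.zeros([batch_size, input_size])
#                     for _ in range(dec_steps)]
#   outputs, state = basic_rnn_seq2seq(encoder_inputs, decoder_inputs, cell)
#   # outputs is a list of dec_steps tensors, each of shape [batch_size, 24].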
def embedding_rnn_decoder(decoder_inputs,
                          initial_state,
                          cell,
                          num_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          update_embedding_for_previous=True,
                          scope=None):
  """RNN decoder with embedding and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each fed
      previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)).
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_decoder".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope:
    if output_projection is not None:
      dtype = scope.dtype
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = (embedding_ops.embedding_lookup(embedding, i)
               for i in decoder_inputs)
    return rnn_decoder(
        emb_inp, initial_state, cell, loop_function=loop_function)


def embedding_rnn_seq2seq(encoder_inputs,
                          decoder_inputs,
                          cell,
                          num_encoder_symbols,
                          num_decoder_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          dtype=None,
                          scope=None):
  """Embedding RNN sequence-to-sequence model.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs an RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors. The
        output is of shape [batch_size x cell.output_size] when
        output_projection is not None (and represents the dense representation
        of predicted tokens). It is of shape [batch_size x num_decoder_symbols]
        when output_projection is None.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope:
    if dtype is not None:
      scope.set_dtype(dtype)
    else:
      dtype = scope.dtype

    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)

    if isinstance(feed_previous, bool):
      return embedding_rnn_decoder(
          decoder_inputs,
          encoder_state,
          cell,
          num_decoder_symbols,
          embedding_size,
          output_projection=output_projection,
          feed_previous=feed_previous)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_rnn_decoder(
            decoder_inputs,
            encoder_state,
            cell,
            num_decoder_symbols,
            embedding_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state
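# Illustrative sketch (an assumption, not part of the library): a small
# embedding_rnn_seq2seq graph over integer token ids. Vocabulary sizes and
# sequence lengths are arbitrary.
#
#   import tensorflow as tf
#
#   batch_size, enc_steps, dec_steps = 32, 7, 8
#   cell = tf.nn.rnn_cell.GRUCell(64)
#   encoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(enc_steps)]
#   decoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(dec_steps)]
#   outputs, state = embedding_rnn_seq2seq(
#       encoder_inputs, decoder_inputs, cell,
#       num_encoder_symbols=1000, num_decoder_symbols=1000,
#       embedding_size=64, feed_previous=False)
#   # Each element of outputs has shape [batch_size, 1000] because no
#   # output_projection was given.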
def embedding_tied_rnn_seq2seq(encoder_inputs,
                               decoder_inputs,
                               cell,
                               num_symbols,
                               embedding_size,
                               num_decoder_symbols=None,
                               output_projection=None,
                               feed_previous=False,
                               dtype=None,
                               scope=None):
  """Embedding RNN sequence-to-sequence model with tied (shared) parameters.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_symbols x input_size]). Then it runs an RNN to encode embedded
  encoder_inputs into a state vector. Next, it embeds decoder_inputs using
  the same embedding. Then it runs an RNN decoder, initialized with the last
  encoder state, on embedded decoder_inputs. The decoder output is over symbols
  from 0 to num_decoder_symbols - 1 if num_decoder_symbols is not None;
  otherwise it is over 0 to num_symbols - 1.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_symbols: Integer; number of symbols for both encoder and decoder.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_decoder_symbols: Integer; number of output symbols for decoder. If
      provided, the decoder output is over symbols 0 to num_decoder_symbols - 1.
      Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that
      this assumes that the vocabulary is set up such that the first
      num_decoder_symbols of num_symbols are part of decoding.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has
      shape [num_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype to use for the initial RNN states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_tied_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_symbols] containing the generated
        outputs where output_symbols = num_decoder_symbols if
        num_decoder_symbols is not None otherwise output_symbols = num_symbols.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  with variable_scope.variable_scope(
      scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    if output_projection is not None:
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable(
        "embedding", [num_symbols, embedding_size], dtype=dtype)

    emb_encoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in encoder_inputs
    ]
    emb_decoder_inputs = [
        embedding_ops.embedding_lookup(embedding, x) for x in decoder_inputs
    ]

    output_symbols = num_symbols
    if num_decoder_symbols is not None:
      output_symbols = num_decoder_symbols
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols)

    if isinstance(feed_previous, bool):
      loop_function = _extract_argmax_and_embed(embedding, output_projection,
                                                True) if feed_previous else None
      return tied_rnn_seq2seq(
          emb_encoder_inputs,
          emb_decoder_inputs,
          cell,
          loop_function=loop_function,
          dtype=dtype)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      loop_function = _extract_argmax_and_embed(
          embedding, output_projection, False) if feed_previous_bool else None
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = tied_rnn_seq2seq(
            emb_encoder_inputs,
            emb_decoder_inputs,
            cell,
            loop_function=loop_function,
            dtype=dtype)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    # Calculate zero-state to know its structure.
    static_batch_size = encoder_inputs[0].get_shape()[0]
    for inp in encoder_inputs[1:]:
      static_batch_size.merge_with(inp.get_shape()[0])
    batch_size = static_batch_size.value
    if batch_size is None:
      batch_size = array_ops.shape(encoder_inputs[0])[0]
    zero_state = cell.zero_state(batch_size, dtype)
    if nest.is_sequence(zero_state):
      state = nest.pack_sequence_as(
          structure=zero_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state


def attention_decoder(decoder_inputs,
                      initial_state,
                      attention_states,
                      cell,
                      output_size=None,
                      num_heads=1,
                      loop_function=None,
                      dtype=None,
                      scope=None,
                      initial_state_attention=False):
  """RNN decoder with attention for the sequence-to-sequence model.

  In this context "attention" means that, during decoding, the RNN can look up
  information in the additional tensor attention_states, and it does this by
  focusing on a few entries from the tensor. This model has proven to yield
  especially good results in a number of sequence-to-sequence tasks. This
  implementation is based on http://arxiv.org/abs/1412.7449 (see below for
  details). It is recommended for complex sequence-to-sequence tasks.

  Args:
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    output_size: Size of the output vectors; if None, we use cell.output_size.
    num_heads: Number of attention heads that read from attention_states.
    loop_function: If not None, this function will be applied to the i-th output
      in order to generate the i+1-th input, and decoder_inputs will be ignored,
      except for the first element ("GO" symbol). This can be used for decoding,
      but also for training to emulate http://arxiv.org/abs/1506.03099.
      Signature -- loop_function(prev, i) = next
        * prev is a 2D Tensor of shape [batch_size x output_size],
        * i is an integer, the step number (when advanced control is needed),
        * next is a 2D Tensor of shape [batch_size x input_size].
    dtype: The dtype to use for the RNN initial state (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors of
        shape [batch_size x output_size]. These represent the generated outputs.
        Output i is computed from input i (which is either the i-th element
        of decoder_inputs or loop_function(output {i-1}, i)) as follows.
        First, we run the cell on a combination of the input and previous
        attention masks:
          cell_output, new_state = cell(linear(input, prev_attn), prev_state).
        Then, we calculate new attention masks:
          new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
        and then we calculate the output:
          output = linear(cell_output, new_attn).
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: when num_heads is not positive, there are no inputs, shapes
      of attention_states are not set, or input size cannot be inferred
      from the input.
  """
  if not decoder_inputs:
    raise ValueError("Must provide at least 1 input to attention decoder.")
  if num_heads < 1:
    raise ValueError("With fewer than 1 head, use a non-attention decoder.")
  if attention_states.get_shape()[2].value is None:
    raise ValueError("Shape[2] of attention_states must be known: %s" %
                     attention_states.get_shape())
  if output_size is None:
    output_size = cell.output_size

  with variable_scope.variable_scope(
      scope or "attention_decoder", dtype=dtype) as scope:
    dtype = scope.dtype

    batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
    attn_length = attention_states.get_shape()[1].value
    if attn_length is None:
      attn_length = array_ops.shape(attention_states)[1]
    attn_size = attention_states.get_shape()[2].value

    # To calculate W1 * h_t we use a 1-by-1 convolution; we need to reshape
    # beforehand.
    hidden = array_ops.reshape(attention_states,
                               [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable(
          "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size],
          dtype=dtype)
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable(
              "AttnV_%d" % a, [attention_vec_size], dtype=dtype))

    state = initial_state

    def attention(query):
      """Put attention masks on hidden using hidden_features and query."""
      ds = []  # Results of attention reads will be stored here.
      if nest.is_sequence(query):  # If the query is a tuple, flatten it.
        query_list = nest.flatten(query)
        for q in query_list:  # Check that ndims == 2 if specified.
          ndims = q.get_shape().ndims
          if ndims:
            assert ndims == 2
        query = array_ops.concat(query_list, 1)
      for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          y = Linear(query, attention_vec_size, True)(query)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          y = math_ops.cast(y, dtype)
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                                  [2, 3])
          a = nn_ops.softmax(math_ops.cast(s, dtype=dtypes.float32))
          # Now calculate the attention-weighted vector d.
          a = math_ops.cast(a, dtype)
          d = math_ops.reduce_sum(
              array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

    outputs = []
    prev = None
    batch_attn_size = array_ops.stack([batch_size, attn_size])
    attns = [
        array_ops.zeros(
            batch_attn_size, dtype=dtype) for _ in xrange(num_heads)
    ]
    for a in attns:  # Ensure the second shape of attention vectors is set.
      a.set_shape([None, attn_size])
    if initial_state_attention:
      attns = attention(initial_state)
    for i, inp in enumerate(decoder_inputs):
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      # If loop_function is set, we use it instead of decoder_inputs.
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      # Merge input and previous attentions into one vector of the right size.
      input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)

      inputs = [inp] + attns
      inputs = [math_ops.cast(e, dtype) for e in inputs]
      x = Linear(inputs, input_size, True)(inputs)
      # Run the RNN.
      cell_output, state = cell(x, state)
      # Run the attention mechanism.
      if i == 0 and initial_state_attention:
        with variable_scope.variable_scope(
            variable_scope.get_variable_scope(), reuse=True):
          attns = attention(state)
      else:
        attns = attention(state)

      with variable_scope.variable_scope("AttnOutputProjection"):
        cell_output = math_ops.cast(cell_output, dtype)
        inputs = [cell_output] + attns
        output = Linear(inputs, output_size, True)(inputs)
      if loop_function is not None:
        prev = output
      outputs.append(output)

  return outputs, state
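# Illustrative sketch (an assumption, not part of the library): calling
# attention_decoder with precomputed encoder outputs stacked into
# attention_states. Shapes are arbitrary.
#
#   import tensorflow as tf
#
#   batch_size, input_size, dec_steps = 16, 32, 5
#   attn_length, attn_size = 7, 64
#   cell = tf.nn.rnn_cell.GRUCell(64)
#   decoder_inputs = [tf.zeros([batch_size, input_size])
#                     for _ in range(dec_steps)]
#   attention_states = tf.zeros([batch_size, attn_length, attn_size])
#   initial_state = cell.zero_state(batch_size, tf.float32)
#   outputs, state = attention_decoder(
#       decoder_inputs, initial_state, attention_states, cell, num_heads=1)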
def embedding_attention_decoder(decoder_inputs,
                                initial_state,
                                attention_states,
                                cell,
                                num_symbols,
                                embedding_size,
                                num_heads=1,
                                output_size=None,
                                output_projection=None,
                                feed_previous=False,
                                update_embedding_for_previous=True,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """RNN decoder with embedding and attention and a pure-decoding option.

  Args:
    decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
    initial_state: 2D Tensor [batch_size x cell.state_size].
    attention_states: 3D Tensor [batch_size x attn_length x attn_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function.
    num_symbols: Integer, how many symbols come into the embedding.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_size: Size of the output vectors; if None, use cell.output_size.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_symbols] and B has shape
      [num_symbols]; if provided and feed_previous=True, each fed previous
      output will first be multiplied by W and have B added to it.
    feed_previous: Boolean; if True, only the first of decoder_inputs will be
      used (the "GO" symbol), and all other decoder inputs will be generated by:
        next = embedding_lookup(embedding, argmax(previous_output)).
      In effect, this implements a greedy decoder. It can also be used
      during training to emulate http://arxiv.org/abs/1506.03099.
      If False, decoder_inputs are used as given (the standard decoder case).
    update_embedding_for_previous: Boolean; if False and feed_previous=True,
      only the embedding for the first symbol of decoder_inputs (the "GO"
      symbol) will be updated by back propagation. Embeddings for the symbols
      generated from the decoder itself remain unchanged. This parameter has
      no effect if feed_previous=False.
    dtype: The dtype to use for the RNN initial states (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_decoder".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states -- useful when we wish to resume decoding from a previously
      stored decoder state and attention states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].

  Raises:
    ValueError: When output_projection has the wrong shape.
  """
  if output_size is None:
    output_size = cell.output_size
  if output_projection is not None:
    proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
    proj_biases.get_shape().assert_is_compatible_with([num_symbols])

  with variable_scope.variable_scope(
      scope or "embedding_attention_decoder", dtype=dtype) as scope:

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = [
        embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs
    ]
    return attention_decoder(
        emb_inp,
        initial_state,
        attention_states,
        cell,
        output_size=output_size,
        num_heads=num_heads,
        loop_function=loop_function,
        initial_state_attention=initial_state_attention)


def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  """Embedding sequence-to-sequence model with attention.

  This model first embeds encoder_inputs by a newly created embedding (of shape
  [num_encoder_symbols x input_size]). Then it runs an RNN to encode
  embedded encoder_inputs into a state vector. It keeps the outputs of this
  RNN at every step to use for attention later. Next, it embeds decoder_inputs
  by another newly created embedding (of shape [num_decoder_symbols x
  input_size]). Then it runs an attention decoder, initialized with the last
  encoder state, on embedded decoder_inputs and attending to encoder outputs.

  Warning: when output_projection is None, the size of the attention vectors
  and variables will be made proportional to num_decoder_symbols, which can be
  large.

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols: Integer; number of symbols on the decoder side.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    num_heads: Number of attention heads that read from attention_states.
    output_projection: None or a pair (W, B) of output projection weights and
      biases; W has shape [output_size x num_decoder_symbols] and B has
      shape [num_decoder_symbols]; if provided and feed_previous=True, each
      fed previous output will first be multiplied by W and have B added to it.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
      of decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial RNN state (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "embedding_attention_seq2seq".
    initial_state_attention: If False (default), initial attentions are zero.
      If True, initialize the attentions from the initial state and attention
      states.

  Returns:
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x num_decoder_symbols] containing the generated
        outputs.
      state: The state of each decoder cell at the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(
      scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype
    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    encoder_outputs, encoder_state = rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)

    # First calculate a concatenation of encoder outputs to put attention on.
    top_states = [
        array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs
    ]
    attention_states = array_ops.concat(top_states, 1)

    # Decoder.
    output_size = None
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
      output_size = num_decoder_symbols

    if isinstance(feed_previous, bool):
      return embedding_attention_decoder(
          decoder_inputs,
          encoder_state,
          attention_states,
          cell,
          num_decoder_symbols,
          embedding_size,
          num_heads=num_heads,
          output_size=output_size,
          output_projection=output_projection,
          feed_previous=feed_previous,
          initial_state_attention=initial_state_attention)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_attention_decoder(
            decoder_inputs,
            encoder_state,
            attention_states,
            cell,
            num_decoder_symbols,
            embedding_size,
            num_heads=num_heads,
            output_size=output_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False,
            initial_state_attention=initial_state_attention)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state
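# Illustrative sketch (an assumption, not part of the library): the
# attention-based model on integer token ids, with greedy decoding enabled
# via feed_previous=True. Sizes are arbitrary.
#
#   import tensorflow as tf
#
#   batch_size, enc_steps, dec_steps = 32, 10, 12
#   cell = tf.nn.rnn_cell.LSTMCell(128)
#   encoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(enc_steps)]
#   decoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(dec_steps)]
#   outputs, state = embedding_attention_seq2seq(
#       encoder_inputs, decoder_inputs, cell,
#       num_encoder_symbols=10000, num_decoder_symbols=10000,
#       embedding_size=128, feed_previous=True)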
def one2many_rnn_seq2seq(encoder_inputs,
                         decoder_inputs_dict,
                         enc_cell,
                         dec_cells_dict,
                         num_encoder_symbols,
                         num_decoder_symbols_dict,
                         embedding_size,
                         feed_previous=False,
                         dtype=None,
                         scope=None):
  """One-to-many RNN sequence-to-sequence model (multi-task).

  This is a multi-task sequence-to-sequence model with one encoder and multiple
  decoders. A reference for multi-task sequence-to-sequence learning can be
  found here: http://arxiv.org/abs/1511.06114

  Args:
    encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
    decoder_inputs_dict: A dictionary mapping decoder name (string) to
      the corresponding decoder_inputs; each decoder_inputs is a list of 1D
      Tensors of shape [batch_size]; num_decoders is defined as
      len(decoder_inputs_dict).
    enc_cell: tf.nn.rnn_cell.RNNCell defining the encoder cell function and
      size.
    dec_cells_dict: A dictionary mapping decoder name (string) to an
      instance of tf.nn.rnn_cell.RNNCell.
    num_encoder_symbols: Integer; number of symbols on the encoder side.
    num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
      integer specifying number of symbols for the corresponding decoder;
      len(num_decoder_symbols_dict) must be equal to num_decoders.
    embedding_size: Integer, the length of the embedding vector for each symbol.
    feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of
      decoder_inputs will be used (the "GO" symbol), and all other decoder
      inputs will be taken from previous outputs (as in embedding_rnn_decoder).
      If False, decoder_inputs are used as given (the standard decoder case).
    dtype: The dtype of the initial state for both the encoder and decoder
      rnn cells (default: tf.float32).
    scope: VariableScope for the created subgraph; defaults to
      "one2many_rnn_seq2seq".

  Returns:
    A tuple of the form (outputs_dict, state_dict), where:
      outputs_dict: A mapping from decoder name (string) to a list of the same
        length as decoder_inputs_dict[name]; each element in the list is a 2D
        Tensor with shape [batch_size x num_decoder_symbol_list[name]]
        containing the generated outputs.
      state_dict: A mapping from decoder name (string) to the final state of
        the corresponding decoder RNN; it is a 2D Tensor of shape
        [batch_size x cell.state_size].

  Raises:
    TypeError: if enc_cell or any of the dec_cells are not instances of RNNCell.
    ValueError: if len(dec_cells) != len(decoder_inputs_dict).
  """
  outputs_dict = {}
  state_dict = {}

  if not isinstance(enc_cell, rnn_cell_impl.RNNCell):
    raise TypeError("enc_cell is not an RNNCell: %s" % type(enc_cell))
  if set(dec_cells_dict) != set(decoder_inputs_dict):
    raise ValueError("keys of dec_cells_dict != keys of decoder_inputs_dict")
  for dec_cell in dec_cells_dict.values():
    if not isinstance(dec_cell, rnn_cell_impl.RNNCell):
      raise TypeError("dec_cell is not an RNNCell: %s" % type(dec_cell))

  with variable_scope.variable_scope(
      scope or "one2many_rnn_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype

    # Encoder.
    enc_cell = core_rnn_cell.EmbeddingWrapper(
        enc_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    for name, decoder_inputs in decoder_inputs_dict.items():
      num_decoder_symbols = num_decoder_symbols_dict[name]
      dec_cell = dec_cells_dict[name]

      with variable_scope.variable_scope("one2many_decoder_" + str(
          name)) as scope:
        dec_cell = core_rnn_cell.OutputProjectionWrapper(
            dec_cell, num_decoder_symbols)
        if isinstance(feed_previous, bool):
          outputs, state = embedding_rnn_decoder(
              decoder_inputs,
              encoder_state,
              dec_cell,
              num_decoder_symbols,
              embedding_size,
              feed_previous=feed_previous)
        else:
          # If feed_previous is a Tensor, we construct 2 graphs and use cond.
          def filled_embedding_rnn_decoder(feed_previous):
            """The current decoder with a fixed feed_previous parameter."""
            # pylint: disable=cell-var-from-loop
            reuse = None if feed_previous else True
            vs = variable_scope.get_variable_scope()
            with variable_scope.variable_scope(vs, reuse=reuse):
              outputs, state = embedding_rnn_decoder(
                  decoder_inputs,
                  encoder_state,
                  dec_cell,
                  num_decoder_symbols,
                  embedding_size,
                  feed_previous=feed_previous)
            # pylint: enable=cell-var-from-loop
            state_list = [state]
            if nest.is_sequence(state):
              state_list = nest.flatten(state)
            return outputs + state_list

          outputs_and_state = control_flow_ops.cond(
              feed_previous, lambda: filled_embedding_rnn_decoder(True),
              lambda: filled_embedding_rnn_decoder(False))
          # Outputs length is the same as for decoder inputs.
          outputs_len = len(decoder_inputs)
          outputs = outputs_and_state[:outputs_len]
          state_list = outputs_and_state[outputs_len:]
          state = state_list[0]
          if nest.is_sequence(encoder_state):
            state = nest.pack_sequence_as(
                structure=encoder_state, flat_sequence=state_list)
      outputs_dict[name] = outputs
      state_dict[name] = state

  return outputs_dict, state_dict
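# Illustrative sketch (an assumption, not part of the library): a one-to-many
# setup with two decoders sharing a single encoder. Names, vocabulary sizes,
# and lengths are arbitrary.
#
#   import tensorflow as tf
#
#   batch_size, enc_steps, dec_steps = 16, 6, 7
#   enc_cell = tf.nn.rnn_cell.GRUCell(32)
#   dec_cells = {"fr": tf.nn.rnn_cell.GRUCell(32),
#                "de": tf.nn.rnn_cell.GRUCell(32)}
#   encoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(enc_steps)]
#   decoder_inputs = {name: [tf.placeholder(tf.int32, [batch_size])
#                            for _ in range(dec_steps)]
#                     for name in dec_cells}
#   outputs_dict, state_dict = one2many_rnn_seq2seq(
#       encoder_inputs, decoder_inputs, enc_cell, dec_cells,
#       num_encoder_symbols=500,
#       num_decoder_symbols_dict={"fr": 400, "de": 400},
#       embedding_size=32)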
def sequence_loss_by_example(logits,
                             targets,
                             weights,
                             average_across_timesteps=True,
                             softmax_loss_function=None,
                             name=None):
  """Weighted cross-entropy loss for a sequence of logits (per example).

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    name: Optional name for this operation, default: "sequence_loss_by_example".

  Returns:
    1D batch-sized float Tensor: The log-perplexity for each sequence.

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  if len(targets) != len(logits) or len(weights) != len(logits):
    raise ValueError("Lengths of logits, weights, and targets must be the same "
                     "%d, %d, %d." %
                     (len(logits), len(weights), len(targets)))
  with ops.name_scope(name, "sequence_loss_by_example",
                      logits + targets + weights):
    log_perp_list = []
    for logit, target, weight in zip(logits, targets, weights):
      if softmax_loss_function is None:
        # TODO(irving,ebrevdo): This reshape is needed because
        # sequence_loss_by_example is called with scalars sometimes, which
        # violates our general scalar strictness policy.
        target = array_ops.reshape(target, [-1])
        crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
            labels=target, logits=logit)
      else:
        crossent = softmax_loss_function(labels=target, logits=logit)
      log_perp_list.append(crossent * weight)
    log_perps = math_ops.add_n(log_perp_list)
    if average_across_timesteps:
      total_size = math_ops.add_n(weights)
      total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
      log_perps /= total_size
  return log_perps


def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):
  """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.

  Args:
    logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets: List of 1D batch-sized int32 Tensors of the same length as logits.
    weights: List of 1D batch-sized float-Tensors of the same length as logits.
    average_across_timesteps: If set, divide the returned cost by the total
      label weight.
    average_across_batch: If set, divide the returned cost by the batch size.
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A scalar float Tensor: The average log-perplexity per symbol (weighted).

  Raises:
    ValueError: If len(logits) is different from len(targets) or len(weights).
  """
  with ops.name_scope(name, "sequence_loss", logits + targets + weights):
    cost = math_ops.reduce_sum(
        sequence_loss_by_example(
            logits,
            targets,
            weights,
            average_across_timesteps=average_across_timesteps,
            softmax_loss_function=softmax_loss_function))
    if average_across_batch:
      batch_size = array_ops.shape(targets[0])[0]
      return cost / math_ops.cast(batch_size, cost.dtype)
    else:
      return cost
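# Illustrative sketch (an assumption, not part of the library): computing a
# per-step weighted loss for a 3-step decoder over a 5-symbol vocabulary.
#
#   import tensorflow as tf
#
#   batch_size, num_symbols, num_steps = 4, 5, 3
#   logits = [tf.zeros([batch_size, num_symbols]) for _ in range(num_steps)]
#   targets = [tf.zeros([batch_size], dtype=tf.int32)
#              for _ in range(num_steps)]
#   weights = [tf.ones([batch_size]) for _ in range(num_steps)]
#   loss = sequence_loss(logits, targets, weights)
#   per_example = sequence_loss_by_example(logits, targets, weights)
#   # loss is a scalar; per_example has shape [batch_size].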
def model_with_buckets(encoder_inputs,
                       decoder_inputs,
                       targets,
                       weights,
                       buckets,
                       seq2seq,
                       softmax_loss_function=None,
                       per_example_loss=False,
                       name=None):
  """Create a sequence-to-sequence model with support for bucketing.

  The seq2seq argument is a function that defines a sequence-to-sequence model,
  e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(
      x, y, rnn_cell.GRUCell(24))

  Args:
    encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
    decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
    targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
    weights: List of 1D batch-sized float-Tensors to weight the targets.
    buckets: A list of pairs of (input size, output size) for each bucket.
    seq2seq: A sequence-to-sequence model function; it takes 2 inputs that
      agree with encoder_inputs and decoder_inputs, and returns a pair
      consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    per_example_loss: Boolean. If set, the returned loss will be a batch-sized
      tensor of losses for each sequence in the batch. If unset, it will be
      a scalar with the averaged loss from all examples.
    name: Optional name for this operation, defaults to "model_with_buckets".

  Returns:
    A tuple of the form (outputs, losses), where:
      outputs: The outputs for each bucket. Its j'th element consists of a list
        of 2D Tensors. The shape of output tensors can be either
        [batch_size x output_size] or [batch_size x num_decoder_symbols]
        depending on the seq2seq model used.
      losses: List of scalar Tensors, representing losses for each bucket, or,
        if per_example_loss is set, a list of 1D batch-sized float Tensors.

  Raises:
    ValueError: If length of encoder_inputs, targets, or weights is smaller
      than the largest (last) bucket.
  """
  if len(encoder_inputs) < buckets[-1][0]:
    raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
                     "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
  if len(targets) < buckets[-1][1]:
    raise ValueError("Length of targets (%d) must be at least that of last "
                     "bucket (%d)." % (len(targets), buckets[-1][1]))
  if len(weights) < buckets[-1][1]:
    raise ValueError("Length of weights (%d) must be at least that of last "
                     "bucket (%d)." % (len(weights), buckets[-1][1]))

  all_inputs = encoder_inputs + decoder_inputs + targets + weights
  losses = []
  outputs = []
  with ops.name_scope(name, "model_with_buckets", all_inputs):
    for j, bucket in enumerate(buckets):
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
        bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
                                    decoder_inputs[:bucket[1]])
        outputs.append(bucket_outputs)
        if per_example_loss:
          losses.append(
              sequence_loss_by_example(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))
        else:
          losses.append(
              sequence_loss(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))

  return outputs, losses
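# Illustrative sketch (an assumption, not part of the library): bucketed
# losses with two buckets and embedding_rnn_seq2seq. Sizes are arbitrary; in
# practice encoder and decoder inputs are padded up to the largest bucket.
#
#   import tensorflow as tf
#
#   batch_size, vocab = 32, 1000
#   buckets = [(5, 10), (10, 15)]
#   max_enc, max_dec = buckets[-1]
#   cell = tf.nn.rnn_cell.GRUCell(64)
#   encoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(max_enc)]
#   decoder_inputs = [tf.placeholder(tf.int32, [batch_size])
#                     for _ in range(max_dec)]
#   targets = decoder_inputs[1:] + [tf.zeros([batch_size], dtype=tf.int32)]
#   weights = [tf.ones([batch_size]) for _ in range(max_dec)]
#
#   def my_seq2seq(enc, dec):
#     return embedding_rnn_seq2seq(
#         enc, dec, cell, num_encoder_symbols=vocab,
#         num_decoder_symbols=vocab, embedding_size=64)
#
#   outputs, losses = model_with_buckets(
#       encoder_inputs, decoder_inputs, targets, weights, buckets, my_seq2seq)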