# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A powerful dynamic attention wrapper object."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import math

import numpy as np

from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.layers import base as layers_base
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


__all__ = [
    "AttentionMechanism",
    "AttentionWrapper",
    "AttentionWrapperState",
    "LuongAttention",
    "BahdanauAttention",
    "hardmax",
    "safe_cumprod",
    "monotonic_attention",
    "BahdanauMonotonicAttention",
    "LuongMonotonicAttention",
]


_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access


class AttentionMechanism(object):

  @property
  def alignments_size(self):
    raise NotImplementedError

  @property
  def state_size(self):
    raise NotImplementedError


class _BaseAttentionMechanism(AttentionMechanism):
  """A base AttentionMechanism class providing common functionality.

  Common functionality includes:
    1. Storing the query and memory layers.
    2. Preprocessing and storing the memory.
  """

  def __init__(self,
               query_layer,
               memory,
               probability_fn,
               memory_sequence_length=None,
               memory_layer=None,
               check_inner_dims_defined=True,
               score_mask_value=None,
               name=None):
    """Construct base AttentionMechanism class.

    Args:
      query_layer: Callable.  Instance of `tf.layers.Layer`.  The layer's depth
        must match the depth of `memory_layer`.  If `query_layer` is not
        provided, the shape of `query` must match that of `memory_layer`.
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      probability_fn: A `callable`.
        Converts the score and previous alignments to probabilities.
        Its signature should be:
        `probabilities = probability_fn(score, state)`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory.  If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      memory_layer: Instance of `tf.layers.Layer` (may be None).  The layer's
        depth must match the depth of `query_layer`.
        If `memory_layer` is not provided, the shape of `memory` must match
        that of `query_layer`.
      check_inner_dims_defined: Python boolean.  If `True`, the `memory`
        argument's shape is checked to ensure all but the two outermost
        dimensions are fully defined.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`.  The default is -inf.  Only used if
        `memory_sequence_length` is not None.
      name: Name to use when creating ops.
    """
    if (query_layer is not None
        and not isinstance(query_layer, layers_base.Layer)):
      raise TypeError(
          "query_layer is not a Layer: %s" % type(query_layer).__name__)
    if (memory_layer is not None
        and not isinstance(memory_layer, layers_base.Layer)):
      raise TypeError(
          "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
    self._query_layer = query_layer
    self._memory_layer = memory_layer
    self.dtype = memory_layer.dtype
    if not callable(probability_fn):
      raise TypeError("probability_fn must be callable, saw type: %s" %
                      type(probability_fn).__name__)
    if score_mask_value is None:
      score_mask_value = dtypes.as_dtype(
          self._memory_layer.dtype).as_numpy_dtype(-np.inf)
    self._probability_fn = lambda score, prev: (  # pylint:disable=g-long-lambda
        probability_fn(
            _maybe_mask_score(score,
                              memory_sequence_length=memory_sequence_length,
                              score_mask_value=score_mask_value),
            prev))
    with ops.name_scope(
        name, "BaseAttentionMechanismInit", nest.flatten(memory)):
      self._values = _prepare_memory(
          memory, memory_sequence_length=memory_sequence_length,
          check_inner_dims_defined=check_inner_dims_defined)
      self._keys = (
          self.memory_layer(self._values) if self.memory_layer  # pylint: disable=not-callable
          else self._values)
      self._batch_size = (
          tensor_shape.dimension_value(self._keys.shape[0]) or
          array_ops.shape(self._keys)[0])
      self._alignments_size = (
          tensor_shape.dimension_value(self._keys.shape[1]) or
          array_ops.shape(self._keys)[1])

  @property
  def memory_layer(self):
    return self._memory_layer

  @property
  def query_layer(self):
    return self._query_layer

  @property
  def values(self):
    return self._values

  @property
  def keys(self):
    return self._keys

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def alignments_size(self):
    return self._alignments_size

  @property
  def state_size(self):
    return self._alignments_size

  def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step (e.g. monotonic attention).

    The default behavior is to return a tensor of all zeros.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    max_time = self._alignments_size
    return _zero_state_tensors(max_time, batch_size, dtype)

  def initial_state(self, batch_size, dtype):
    """Creates the initial state values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step (e.g. monotonic attention).

    The default behavior is to return the same output as initial_alignments.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A structure of all-zero tensors with shapes as described by `state_size`.
    """
    return self.initial_alignments(batch_size, dtype)


class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer):
  """A base AttentionMechanism class providing common functionality.

  Common functionality includes:
    1. Storing the query and memory layers.
    2. Preprocessing and storing the memory.

  Note that this layer takes memory as its init parameter, which is an
  anti-pattern of the Keras API; we have to keep the memory as an init
  parameter for performance and dependency reasons. Under the hood, during
  `__init__()`, it will invoke `base_layer.__call__(memory, setup_memory=True)`.
  This lets Keras keep track of the memory tensor as an input of this layer.
  Once `__init__()` is done, the user can query the attention via
  `score = att_obj([query, state])` and use it as a normal Keras layer.

  Special care is needed when using this class as the base layer for a new
  attention mechanism:
    1. `build()` could be invoked more than once, so make sure weights are not
       duplicated.
    2. `Layer.get_weights()` might return a different set of weights if the
       instance has a `query_layer`. The `query_layer` weights are not
       initialized until the memory is configured.

  Also note that this layer does not work with a Keras model when
  `model.compile(run_eagerly=True)` is used, because this layer is stateful.
  Support for that will be added in a future version.
  """

  def __init__(self,
               memory,
               probability_fn,
               query_layer=None,
               memory_layer=None,
               memory_sequence_length=None,
               **kwargs):
    """Construct base AttentionMechanism class.

    Args:
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      probability_fn: A `callable`.  Converts the score and previous alignments
        to probabilities. Its signature should be:
        `probabilities = probability_fn(score, state)`.
      query_layer: (optional) Instance of `tf.keras.Layer`.  The layer's depth
        must match the depth of `memory_layer`.  If `query_layer` is not
        provided, the shape of `query` must match that of `memory_layer`.
      memory_layer: (optional) Instance of `tf.keras.Layer`.  The layer's
        depth must match the depth of `query_layer`.
        If `memory_layer` is not provided, the shape of `memory` must match
        that of `query_layer`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory.  If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      **kwargs: Dictionary that contains other common arguments for layer
        creation.
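
    A minimal, illustrative usage sketch with a concrete subclass from this
    module (the shapes below are assumptions for the example only):

      # memory: [batch_size, max_time, encoder_depth] encoder outputs.
      attention_mechanism = LuongAttentionV2(units=128, memory=memory)
      # Later, one decoder step: query is [batch_size, query_depth] and
      # state is the previous alignments [batch_size, max_time].
      alignments, next_state = attention_mechanism([query, state])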
279 """ 280 if (query_layer is not None 281 and not isinstance(query_layer, layers.Layer)): 282 raise TypeError( 283 "query_layer is not a Layer: %s" % type(query_layer).__name__) 284 if (memory_layer is not None 285 and not isinstance(memory_layer, layers.Layer)): 286 raise TypeError( 287 "memory_layer is not a Layer: %s" % type(memory_layer).__name__) 288 self.query_layer = query_layer 289 self.memory_layer = memory_layer 290 if self.memory_layer is not None and "dtype" not in kwargs: 291 kwargs["dtype"] = self.memory_layer.dtype 292 super(_BaseAttentionMechanismV2, self).__init__(**kwargs) 293 if not callable(probability_fn): 294 raise TypeError("probability_fn must be callable, saw type: %s" % 295 type(probability_fn).__name__) 296 self.probability_fn = probability_fn 297 298 self.keys = None 299 self.values = None 300 self.batch_size = None 301 self._memory_initialized = False 302 self._check_inner_dims_defined = True 303 self.supports_masking = True 304 self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf) 305 306 if memory is not None: 307 # Setup the memory by self.__call__() with memory and memory_seq_length. 308 # This will make the attention follow the keras convention which takes 309 # all the tensor inputs via __call__(). 310 if memory_sequence_length is None: 311 inputs = memory 312 else: 313 inputs = [memory, memory_sequence_length] 314 315 self.values = super(_BaseAttentionMechanismV2, self).__call__( 316 inputs, setup_memory=True) 317 318 def build(self, input_shape): 319 if not self._memory_initialized: 320 # This is for setting up the memory, which contains memory and optional 321 # memory_sequence_length. Build the memory_layer with memory shape. 322 if self.memory_layer is not None and not self.memory_layer.built: 323 if isinstance(input_shape, list): 324 self.memory_layer.build(input_shape[0]) 325 else: 326 self.memory_layer.build(input_shape) 327 else: 328 # The input_shape should be query.shape and state.shape. Use the query 329 # to init the query layer. 330 if self.query_layer is not None and not self.query_layer.built: 331 self.query_layer.build(input_shape[0]) 332 333 def __call__(self, inputs, **kwargs): 334 """Preprocess the inputs before calling `base_layer.__call__()`. 335 336 Note that there are situation here, one for setup memory, and one with 337 actual query and state. 338 1. When the memory has not been configured, we just pass all the param to 339 base_layer.__call__(), which will then invoke self.call() with proper 340 inputs, which allows this class to setup memory. 341 2. When the memory has already been setup, the input should contain query 342 and state, and optionally processed memory. If the processed memory is 343 not included in the input, we will have to append it to the inputs and 344 give it to the base_layer.__call__(). The processed memory is the output 345 of first invocation of self.__call__(). If we don't add it here, then from 346 keras perspective, the graph is disconnected since the output from 347 previous call is never used. 348 349 Args: 350 inputs: the inputs tensors. 351 **kwargs: dict, other keyeword arguments for the `__call__()` 352 """ 353 if self._memory_initialized: 354 if len(inputs) not in (2, 3): 355 raise ValueError("Expect the inputs to have 2 or 3 tensors, got %d" % 356 len(inputs)) 357 if len(inputs) == 2: 358 # We append the calculated memory here so that the graph will be 359 # connected. 
        inputs.append(self.values)
    return super(_BaseAttentionMechanismV2, self).__call__(inputs, **kwargs)

  def call(self, inputs, mask=None, setup_memory=False, **kwargs):
    """Setup the memory or query the attention.

    There are two cases here: one for setting up the memory, and the other for
    querying the attention score. `setup_memory` is the flag that indicates
    which mode it is; the input list is treated differently based on that flag.

    Args:
      inputs: a list of tensors that contains either `query` and `state`, or
        `memory` and `memory_sequence_length`.
        `query` is the tensor of dtype matching `memory` and shape
        `[batch_size, query_depth]`.
        `state` is the tensor of dtype matching `memory` and shape
        `[batch_size, alignments_size]`. (`alignments_size` is memory's
        `max_time`).
        `memory` is the memory to query; usually the output of an RNN encoder.
        The tensor should be shaped `[batch_size, max_time, ...]`.
        `memory_sequence_length` (optional) is the sequence lengths for the
        batch entries in memory. If provided, the memory tensor rows are masked
        with zeros for values past the respective sequence lengths.
      mask: optional bool tensor with shape `[batch, max_time]` for the mask of
        memory. If it is not None, the corresponding item of the memory should
        be filtered out during calculation.
      setup_memory: boolean, whether the input is for setting up memory, or
        querying attention.
      **kwargs: Dict, other keyword arguments for the call method.
    Returns:
      Either the processed memory or the attention score, based on
      `setup_memory`.
    """
    if setup_memory:
      if isinstance(inputs, list):
        if len(inputs) not in (1, 2):
          raise ValueError("Expect inputs to have 1 or 2 tensors, got %d" %
                           len(inputs))
        memory = inputs[0]
        memory_sequence_length = inputs[1] if len(inputs) == 2 else None
        memory_mask = mask
      else:
        memory, memory_sequence_length = inputs, None
        memory_mask = mask
      self._setup_memory(memory, memory_sequence_length, memory_mask)
      # We force self.built to False here since only the memory has been
      # initialized, but the real query/state has not been passed to call()
      # yet. The layer should be built and called again.
      self.built = False
      # Return the processed memory in order to create the Keras connectivity
      # data for it.
      return self.values
    else:
      if not self._memory_initialized:
        raise ValueError("Cannot query the attention before the setup of "
                         "memory")
      if len(inputs) not in (2, 3):
        raise ValueError("Expect the inputs to have query, state, and optional "
                         "processed memory, got %d items" % len(inputs))
      # Ignore the rest of the inputs and only care about the query and state.
      query, state = inputs[0], inputs[1]
      return self._calculate_attention(query, state)

  def _setup_memory(self, memory, memory_sequence_length=None,
                    memory_mask=None):
    """Pre-process the memory before actually querying it.

    This should only be called once at the first invocation of call().

    Args:
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory. If provided, the memory tensor rows are masked with zeros for
        values past the respective sequence lengths.
      memory_mask: (Optional) The boolean tensor with shape `[batch_size,
        max_time]`.
        For any value equal to False, the corresponding value in memory should
        be ignored.
    """
    if self._memory_initialized:
      raise ValueError("The memory for the attention has already been setup.")
    if memory_sequence_length is not None and memory_mask is not None:
      raise ValueError("memory_sequence_length and memory_mask cannot be "
                       "used at same time for attention.")
    with ops.name_scope(
        self.name, "BaseAttentionMechanismInit", nest.flatten(memory)):
      self.values = _prepare_memory(
          memory,
          memory_sequence_length=memory_sequence_length,
          memory_mask=memory_mask,
          check_inner_dims_defined=self._check_inner_dims_defined)
      # Mark the value as checked since the memory and memory mask might not
      # be passed from __call__(), and so do not have proper keras metadata.
      # TODO(omalleyt): Remove this hack once the mask has proper keras
      # history.
      base_layer_utils.mark_checked(self.values)
      if self.memory_layer is not None:
        self.keys = self.memory_layer(self.values)
      else:
        self.keys = self.values
      self.batch_size = (
          tensor_shape.dimension_value(self.keys.shape[0]) or
          array_ops.shape(self.keys)[0])
      self._alignments_size = (
          tensor_shape.dimension_value(self.keys.shape[1]) or
          array_ops.shape(self.keys)[1])
      if memory_mask is not None:
        unwrapped_probability_fn = self.probability_fn
        def _mask_probability_fn(score, prev):
          return unwrapped_probability_fn(
              _maybe_mask_score(
                  score,
                  memory_mask=memory_mask,
                  memory_sequence_length=memory_sequence_length,
                  score_mask_value=self.score_mask_value), prev)
        self.probability_fn = _mask_probability_fn
    self._memory_initialized = True

  def _calculate_attention(self, query, state):
    raise NotImplementedError(
        "_calculate_attention needs to be implemented by subclasses.")

  def compute_mask(self, inputs, mask=None):
    # The real inputs of the attention are the query and state, and the memory
    # layer mask shouldn't be passed down. Return None for all output masks
    # here.
    return None, None

  def get_config(self):
    config = {}
    # Since the probability_fn is likely to be a wrapped function, the child
    # class should preserve the original function and how it's wrapped.

    if self.query_layer is not None:
      config["query_layer"] = {
          "class_name": self.query_layer.__class__.__name__,
          "config": self.query_layer.get_config(),
      }
    if self.memory_layer is not None:
      config["memory_layer"] = {
          "class_name": self.memory_layer.__class__.__name__,
          "config": self.memory_layer.get_config(),
      }
    # memory is a required init parameter and it's a tensor. It cannot be
    # serialized to config, so we put a placeholder for it.
    config["memory"] = None
    base_config = super(_BaseAttentionMechanismV2, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _process_probability_fn(self, func_name):
    """Helper method to retrieve the probability function by string input."""
    valid_probability_fns = {
        "softmax": nn_ops.softmax,
        "hardmax": hardmax,
    }
    if func_name not in valid_probability_fns.keys():
      raise ValueError("Invalid probability function: %s, options are %s" %
                       (func_name, valid_probability_fns.keys()))
    return valid_probability_fns[func_name]

  @classmethod
  def deserialize_inner_layer_from_config(cls, config, custom_objects):
    """Helper method that reconstructs the query and memory layers from config.

    In the get_config() method, the query and memory layer configs are
    serialized into dicts for persistence. This method performs the reverse
    action, reconstructing the layers from the config.

    Args:
      config: dict, the configs that will be used to reconstruct the object.
      custom_objects: dict mapping class names (or function names) of custom
        (non-Keras) objects to class/functions.
    Returns:
      config: dict, the config with layer instances created, which is ready to
        be used as init parameters.
    """
    # Reconstruct the query and memory layer for parent class.
    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    # Instead of updating the input, create a copy and use that.
    config = config.copy()
    query_layer_config = config.pop("query_layer", None)
    if query_layer_config:
      query_layer = deserialize_layer(query_layer_config,
                                      custom_objects=custom_objects)
      config["query_layer"] = query_layer
    memory_layer_config = config.pop("memory_layer", None)
    if memory_layer_config:
      memory_layer = deserialize_layer(memory_layer_config,
                                       custom_objects=custom_objects)
      config["memory_layer"] = memory_layer
    return config

  @property
  def alignments_size(self):
    return self._alignments_size

  @property
  def state_size(self):
    return self._alignments_size

  def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step (e.g. monotonic attention).

    The default behavior is to return a tensor of all zeros.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    max_time = self._alignments_size
    return _zero_state_tensors(max_time, batch_size, dtype)

  def initial_state(self, batch_size, dtype):
    """Creates the initial state values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step (e.g. monotonic attention).

    The default behavior is to return the same output as initial_alignments.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A structure of all-zero tensors with shapes as described by `state_size`.
    """
    return self.initial_alignments(batch_size, dtype)


def _luong_score(query, keys, scale):
  """Implements Luong-style (multiplicative) scoring function.

  This attention has two forms.  The first is standard Luong attention,
  as described in:

  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
  "Effective Approaches to Attention-based Neural Machine Translation."
  EMNLP 2015.  https://arxiv.org/abs/1508.04025

  The second is the scaled form inspired partly by the normalized form of
  Bahdanau attention.

  To enable the second form, call this function with `scale=True`.

  Args:
    query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    scale: The optional tensor to scale the attention score.

  Returns:
    A `[batch_size, max_time]` tensor of unnormalized score values.

  Raises:
    ValueError: If `keys` and `query` depths do not match.
  """
  depth = query.get_shape()[-1]
  key_units = keys.get_shape()[-1]
  if depth != key_units:
    raise ValueError(
        "Incompatible or unknown inner dimensions between query and keys. "
        "Query (%s) has units: %s.  Keys (%s) have units: %s.  "
        "Perhaps you need to set num_units to the keys' dimension (%s)?"
        % (query, depth, keys, key_units, key_units))

  # Reshape from [batch_size, depth] to [batch_size, 1, depth]
  # for matmul.
  query = array_ops.expand_dims(query, 1)

  # Inner product along the query units dimension.
  # matmul shapes: query is [batch_size, 1, depth] and
  #                keys is [batch_size, max_time, depth].
  # the inner product is asked to **transpose keys' inner shape** to get a
  # batched matmul on:
  #   [batch_size, 1, depth] . [batch_size, depth, max_time]
  # resulting in an output shape of:
  #   [batch_size, 1, max_time].
  # we then squeeze out the center singleton dimension.
  score = math_ops.matmul(query, keys, transpose_b=True)
  score = array_ops.squeeze(score, [1])

  if scale is not None:
    score = scale * score
  return score


class LuongAttention(_BaseAttentionMechanism):
  """Implements Luong-style (multiplicative) attention scoring.

  This attention has two forms.  The first is standard Luong attention,
  as described in:

  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
  [Effective Approaches to Attention-based Neural Machine Translation.
  EMNLP 2015.](https://arxiv.org/abs/1508.04025)

  The second is the scaled form inspired partly by the normalized form of
  Bahdanau attention.

  To enable the second form, construct the object with parameter
  `scale=True`.
  """

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               scale=False,
               probability_fn=None,
               score_mask_value=None,
               dtype=None,
               name="LuongAttention"):
    """Construct the AttentionMechanism mechanism.

    Args:
      num_units: The depth of the attention mechanism.
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
        in memory.  If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      scale: Python boolean.  Whether to scale the energy term.
      probability_fn: (optional) A `callable`.  Converts the score to
        probabilities.  The default is `tf.nn.softmax`. Other options include
        `tf.contrib.seq2seq.hardmax` and `tf.contrib.sparsemax.sparsemax`.
        Its signature should be: `probabilities = probability_fn(score)`.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`.  The default is -inf.  Only used if
        `memory_sequence_length` is not None.
      dtype: The data type for the memory layer of the attention mechanism.
      name: Name to use when creating ops.
    """
    # For LuongAttention, we only transform the memory layer; thus
    # num_units **must** match the expected query depth.
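    # As an illustrative sketch of the math implemented by `_luong_score`
    # above: with keys = memory_layer(memory) of shape
    # [batch_size, max_time, num_units] and query of shape
    # [batch_size, num_units], the unnormalized score is
    #   score[b, t] = sum_d query[b, d] * keys[b, t, d],
    # optionally multiplied by the learned scalar `attention_g` when
    # `scale=True`.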
    if probability_fn is None:
      probability_fn = nn_ops.softmax
    if dtype is None:
      dtype = dtypes.float32
    wrapped_probability_fn = lambda score, _: probability_fn(score)
    super(LuongAttention, self).__init__(
        query_layer=None,
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
        name=name)
    self._num_units = num_units
    self._scale = scale
    self._name = name

  def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
    """
    with variable_scope.variable_scope(None, "luong_attention", [query]):
      attention_g = None
      if self._scale:
        attention_g = variable_scope.get_variable(
            "attention_g", dtype=query.dtype,
            initializer=init_ops.ones_initializer, shape=())
      score = _luong_score(query, self._keys, attention_g)
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state


class LuongAttentionV2(_BaseAttentionMechanismV2):
  """Implements Luong-style (multiplicative) attention scoring.

  This attention has two forms.  The first is standard Luong attention,
  as described in:

  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
  [Effective Approaches to Attention-based Neural Machine Translation.
  EMNLP 2015.](https://arxiv.org/abs/1508.04025)

  The second is the scaled form inspired partly by the normalized form of
  Bahdanau attention.

  To enable the second form, construct the object with parameter
  `scale=True`.
  """

  def __init__(self,
               units,
               memory,
               memory_sequence_length=None,
               scale=False,
               probability_fn="softmax",
               dtype=None,
               name="LuongAttention",
               **kwargs):
    """Construct the AttentionMechanism mechanism.

    Args:
      units: The depth of the attention mechanism.
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
        in memory.  If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      scale: Python boolean.  Whether to scale the energy term.
      probability_fn: (optional) string, the name of the function used to
        convert the attention score to probabilities. The default is `softmax`,
        which is `tf.nn.softmax`. The other option is `hardmax`, which is the
        hardmax() within this module. Any other value will result in a
        validation error.
      dtype: The data type for the memory layer of the attention mechanism.
      name: Name to use when creating ops.
      **kwargs: Dictionary that contains other common arguments for layer
        creation.
    """
    # For LuongAttention, we only transform the memory layer; thus
    # num_units **must** match the expected query depth.
    self.probability_fn_name = probability_fn
    probability_fn = self._process_probability_fn(self.probability_fn_name)
    wrapped_probability_fn = lambda score, _: probability_fn(score)
    if dtype is None:
      dtype = dtypes.float32
    memory_layer = kwargs.pop("memory_layer", None)
    if not memory_layer:
      memory_layer = layers.Dense(
          units, name="memory_layer", use_bias=False, dtype=dtype)
    self.units = units
    self.scale = scale
    self.scale_weight = None
    super(LuongAttentionV2, self).__init__(
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        query_layer=None,
        memory_layer=memory_layer,
        probability_fn=wrapped_probability_fn,
        name=name,
        dtype=dtype,
        **kwargs)

  def build(self, input_shape):
    super(LuongAttentionV2, self).build(input_shape)
    if self.scale and self.scale_weight is None:
      self.scale_weight = self.add_weight(
          "attention_g", initializer=init_ops.ones_initializer, shape=())
    self.built = True

  def _calculate_attention(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
      next_state: Same as the alignments.
    """
    score = _luong_score(query, self.keys, self.scale_weight)
    alignments = self.probability_fn(score, state)
    next_state = alignments
    return alignments, next_state

  def get_config(self):
    config = {
        "units": self.units,
        "scale": self.scale,
        "probability_fn": self.probability_fn_name,
    }
    base_config = super(LuongAttentionV2, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config(
        config, custom_objects=custom_objects)
    return cls(**config)


def _bahdanau_score(processed_query, keys, attention_v,
                    attention_g=None, attention_b=None):
  """Implements Bahdanau-style (additive) scoring function.

  This attention has two forms.  The first is Bahdanau attention,
  as described in:

  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form.  This form is inspired by the
  weight normalization article:

  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, please pass in attention_g and attention_b.

  Args:
    processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    attention_v: Tensor, shape `[num_units]`.
    attention_g: Optional scalar tensor for normalization.
    attention_b: Optional tensor with shape `[num_units]` for normalization.

  Returns:
    A `[batch_size, max_time]` tensor of unnormalized score values.
  """
  # Reshape from [batch_size, ...]
  # to [batch_size, 1, ...] for broadcasting.
  processed_query = array_ops.expand_dims(processed_query, 1)
  if attention_g is not None and attention_b is not None:
    normed_v = attention_g * attention_v * math_ops.rsqrt(
        math_ops.reduce_sum(math_ops.square(attention_v)))
    return math_ops.reduce_sum(
        normed_v * math_ops.tanh(keys + processed_query + attention_b), [2])
  else:
    return math_ops.reduce_sum(
        attention_v * math_ops.tanh(keys + processed_query), [2])


class BahdanauAttention(_BaseAttentionMechanism):
  """Implements Bahdanau-style (additive) attention.

  This attention has two forms.  The first is Bahdanau attention,
  as described in:

  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form.  This form is inspired by the
  weight normalization article:

  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, construct the object with parameter
  `normalize=True`.
  """

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               normalize=False,
               probability_fn=None,
               score_mask_value=None,
               dtype=None,
               name="BahdanauAttention"):
    """Construct the Attention mechanism.

    Args:
      num_units: The depth of the query mechanism.
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory.  If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      normalize: Python boolean.  Whether to normalize the energy term.
      probability_fn: (optional) A `callable`.  Converts the score to
        probabilities.  The default is `tf.nn.softmax`. Other options include
        `tf.contrib.seq2seq.hardmax` and `tf.contrib.sparsemax.sparsemax`.
        Its signature should be: `probabilities = probability_fn(score)`.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`.  The default is -inf.  Only used if
        `memory_sequence_length` is not None.
      dtype: The data type for the query and memory layers of the attention
        mechanism.
      name: Name to use when creating ops.
    """
    if probability_fn is None:
      probability_fn = nn_ops.softmax
    if dtype is None:
      dtype = dtypes.float32
    wrapped_probability_fn = lambda score, _: probability_fn(score)
    super(BahdanauAttention, self).__init__(
        query_layer=layers_core.Dense(
            num_units, name="query_layer", use_bias=False, dtype=dtype),
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
        name=name)
    self._num_units = num_units
    self._normalize = normalize
    self._name = name

  def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
    """
    with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
      processed_query = self.query_layer(query) if self.query_layer else query
      attention_v = variable_scope.get_variable(
          "attention_v", [self._num_units], dtype=query.dtype)
      if not self._normalize:
        attention_g = None
        attention_b = None
      else:
        attention_g = variable_scope.get_variable(
            "attention_g", dtype=query.dtype,
            initializer=init_ops.constant_initializer(
                math.sqrt((1. / self._num_units))),
            shape=())
        attention_b = variable_scope.get_variable(
            "attention_b", [self._num_units], dtype=query.dtype,
            initializer=init_ops.zeros_initializer())

      score = _bahdanau_score(processed_query, self._keys, attention_v,
                              attention_g=attention_g, attention_b=attention_b)
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state


class BahdanauAttentionV2(_BaseAttentionMechanismV2):
  """Implements Bahdanau-style (additive) attention.

  This attention has two forms.  The first is Bahdanau attention,
  as described in:

  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form.  This form is inspired by the
  weight normalization article:

  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, construct the object with parameter
  `normalize=True`.
  """

  def __init__(self,
               units,
               memory,
               memory_sequence_length=None,
               normalize=False,
               probability_fn="softmax",
               kernel_initializer="glorot_uniform",
               dtype=None,
               name="BahdanauAttention",
               **kwargs):
    """Construct the Attention mechanism.

    Args:
      units: The depth of the query mechanism.
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
        in memory.  If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      normalize: Python boolean.  Whether to normalize the energy term.
      probability_fn: (optional) string, the name of the function used to
        convert the attention score to probabilities. The default is `softmax`,
        which is `tf.nn.softmax`. The other option is `hardmax`, which is the
        hardmax() within this module. Any other value will result in a
        validation error.
      kernel_initializer: (optional) The name of the initializer for the
        attention kernel.
      dtype: The data type for the query and memory layers of the attention
        mechanism.
      name: Name to use when creating ops.
      **kwargs: Dictionary that contains other common arguments for layer
        creation.
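
    A minimal, illustrative usage sketch (shapes are assumptions; pass
    `normalize=True` to enable the weight-normalized energy term):

      # memory: [batch_size, max_time, encoder_depth] encoder outputs.
      attention_mechanism = BahdanauAttentionV2(units=128, memory=memory)
      alignments, next_state = attention_mechanism([query, state])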
1064 """ 1065 self.probability_fn_name = probability_fn 1066 probability_fn = self._process_probability_fn(self.probability_fn_name) 1067 wrapped_probability_fn = lambda score, _: probability_fn(score) 1068 if dtype is None: 1069 dtype = dtypes.float32 1070 query_layer = kwargs.pop("query_layer", None) 1071 if not query_layer: 1072 query_layer = layers.Dense( 1073 units, name="query_layer", use_bias=False, dtype=dtype) 1074 memory_layer = kwargs.pop("memory_layer", None) 1075 if not memory_layer: 1076 memory_layer = layers.Dense( 1077 units, name="memory_layer", use_bias=False, dtype=dtype) 1078 self.units = units 1079 self.normalize = normalize 1080 self.kernel_initializer = initializers.get(kernel_initializer) 1081 self.attention_v = None 1082 self.attention_g = None 1083 self.attention_b = None 1084 super(BahdanauAttentionV2, self).__init__( 1085 memory=memory, 1086 memory_sequence_length=memory_sequence_length, 1087 query_layer=query_layer, 1088 memory_layer=memory_layer, 1089 probability_fn=wrapped_probability_fn, 1090 name=name, 1091 dtype=dtype, 1092 **kwargs) 1093 1094 def build(self, input_shape): 1095 super(BahdanauAttentionV2, self).build(input_shape) 1096 if self.attention_v is None: 1097 self.attention_v = self.add_weight( 1098 "attention_v", [self.units], 1099 dtype=self.dtype, 1100 initializer=self.kernel_initializer) 1101 if self.normalize and self.attention_g is None and self.attention_b is None: 1102 self.attention_g = self.add_weight( 1103 "attention_g", initializer=init_ops.constant_initializer( 1104 math.sqrt((1. / self.units))), shape=()) 1105 self.attention_b = self.add_weight( 1106 "attention_b", shape=[self.units], 1107 initializer=init_ops.zeros_initializer()) 1108 self.built = True 1109 1110 def _calculate_attention(self, query, state): 1111 """Score the query based on the keys and values. 1112 1113 Args: 1114 query: Tensor of dtype matching `self.values` and shape 1115 `[batch_size, query_depth]`. 1116 state: Tensor of dtype matching `self.values` and shape 1117 `[batch_size, alignments_size]` 1118 (`alignments_size` is memory's `max_time`). 1119 1120 Returns: 1121 alignments: Tensor of dtype matching `self.values` and shape 1122 `[batch_size, alignments_size]` (`alignments_size` is memory's 1123 `max_time`). 1124 next_state: same as alignments. 1125 """ 1126 processed_query = self.query_layer(query) if self.query_layer else query 1127 score = _bahdanau_score(processed_query, self.keys, self.attention_v, 1128 attention_g=self.attention_g, 1129 attention_b=self.attention_b) 1130 alignments = self.probability_fn(score, state) 1131 next_state = alignments 1132 return alignments, next_state 1133 1134 def get_config(self): 1135 config = { 1136 "units": self.units, 1137 "normalize": self.normalize, 1138 "probability_fn": self.probability_fn_name, 1139 "kernel_initializer": initializers.serialize(self.kernel_initializer) 1140 } 1141 base_config = super(BahdanauAttentionV2, self).get_config() 1142 return dict(list(base_config.items()) + list(config.items())) 1143 1144 @classmethod 1145 def from_config(cls, config, custom_objects=None): 1146 config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( 1147 config, custom_objects=custom_objects) 1148 return cls(**config) 1149 1150 1151def safe_cumprod(x, *args, **kwargs): 1152 """Computes cumprod of x in logspace using cumsum to avoid underflow. 1153 1154 The cumprod function and its gradient can result in numerical instabilities 1155 when its argument has very small and/or zero values. 
  As long as the argument
  is all positive, we can instead compute the cumulative product as
  exp(cumsum(log(x))).  This function can be called identically to tf.cumprod.

  Args:
    x: Tensor to take the cumulative product of.
    *args: Passed on to cumsum; these are identical to those in cumprod.
    **kwargs: Passed on to cumsum; these are identical to those in cumprod.
  Returns:
    Cumulative product of x.
  """
  with ops.name_scope(None, "SafeCumprod", [x]):
    x = ops.convert_to_tensor(x, name="x")
    tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
    return math_ops.exp(math_ops.cumsum(
        math_ops.log(clip_ops.clip_by_value(x, tiny, 1)), *args, **kwargs))


def monotonic_attention(p_choose_i, previous_attention, mode):
  """Compute monotonic attention distribution from choosing probabilities.

  Monotonic attention implies that the input sequence is processed in an
  explicitly left-to-right manner when generating the output sequence.  In
  addition, once an input sequence element is attended to at a given output
  timestep, elements occurring before it cannot be attended to at subsequent
  output timesteps.  This function generates attention distributions according
  to these assumptions.  For more information, see `Online and Linear-Time
  Attention by Enforcing Monotonic Alignments`.

  Args:
    p_choose_i: Probability of choosing input sequence/memory element i.  Should
      be of shape (batch_size, input_sequence_length), and should all be in the
      range [0, 1].
    previous_attention: The attention distribution from the previous output
      timestep.  Should be of shape (batch_size, input_sequence_length).  For
      the first output timestep, previous_attention[n] should be [1, 0, 0, ...,
      0] for all n in [0, ... batch_size - 1].
    mode: How to compute the attention distribution.  Must be one of
      'recursive', 'parallel', or 'hard'.
        * 'recursive' uses tf.scan to recursively compute the distribution.
          This is slowest but is exact, general, and does not suffer from
          numerical instabilities.
        * 'parallel' uses parallelized cumulative-sum and cumulative-product
          operations to compute a closed-form solution to the recurrence
          relation defining the attention distribution.  This makes it more
          efficient than 'recursive', but it requires numerical checks which
          make the distribution non-exact.  This can be a problem in particular
          when input_sequence_length is long and/or p_choose_i has entries very
          close to 0 or 1.
        * 'hard' requires that the probabilities in p_choose_i are all either 0
          or 1, and subsequently uses a more efficient and exact solution.

  Returns:
    A tensor of shape (batch_size, input_sequence_length) representing the
    attention distributions for each sequence in the batch.

  Raises:
    ValueError: mode is not one of 'recursive', 'parallel', 'hard'.
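
  An illustrative call (hand-picked values; the first output timestep uses a
  dirac previous_attention):

    attention = monotonic_attention(
        p_choose_i=[[0.9, 0.5, 0.1]],
        previous_attention=[[1., 0., 0.]],
        mode="parallel")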
1213 """ 1214 # Force things to be tensors 1215 p_choose_i = ops.convert_to_tensor(p_choose_i, name="p_choose_i") 1216 previous_attention = ops.convert_to_tensor( 1217 previous_attention, name="previous_attention") 1218 if mode == "recursive": 1219 # Use .shape[0] when it's not None, or fall back on symbolic shape 1220 batch_size = tensor_shape.dimension_value( 1221 p_choose_i.shape[0]) or array_ops.shape(p_choose_i)[0] 1222 # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]] 1223 shifted_1mp_choose_i = array_ops.concat( 1224 [array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1) 1225 # Compute attention distribution recursively as 1226 # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i] 1227 # attention[i] = p_choose_i[i]*q[i] 1228 attention = p_choose_i*array_ops.transpose(functional_ops.scan( 1229 # Need to use reshape to remind TF of the shape between loop iterations 1230 lambda x, yz: array_ops.reshape(yz[0]*x + yz[1], (batch_size,)), 1231 # Loop variables yz[0] and yz[1] 1232 [array_ops.transpose(shifted_1mp_choose_i), 1233 array_ops.transpose(previous_attention)], 1234 # Initial value of x is just zeros 1235 array_ops.zeros((batch_size,)))) 1236 elif mode == "parallel": 1237 # safe_cumprod computes cumprod in logspace with numeric checks 1238 cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True) 1239 # Compute recurrence relation solution 1240 attention = p_choose_i*cumprod_1mp_choose_i*math_ops.cumsum( 1241 previous_attention / 1242 # Clip cumprod_1mp to avoid divide-by-zero 1243 clip_ops.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), axis=1) 1244 elif mode == "hard": 1245 # Remove any probabilities before the index chosen last time step 1246 p_choose_i *= math_ops.cumsum(previous_attention, axis=1) 1247 # Now, use exclusive cumprod to remove probabilities after the first 1248 # chosen index, like so: 1249 # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1] 1250 # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0] 1251 # Product of above: [0, 0, 0, 1, 0, 0, 0, 0] 1252 attention = p_choose_i*math_ops.cumprod( 1253 1 - p_choose_i, axis=1, exclusive=True) 1254 else: 1255 raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.") 1256 return attention 1257 1258 1259def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode, 1260 seed=None): 1261 """Attention probability function for monotonic attention. 1262 1263 Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage 1264 the model to make discrete attention decisions, passes them through a sigmoid 1265 to obtain "choosing" probabilities, and then calls monotonic_attention to 1266 obtain the attention distribution. For more information, see 1267 1268 Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, 1269 "Online and Linear-Time Attention by Enforcing Monotonic Alignments." 1270 ICML 2017. https://arxiv.org/abs/1704.00784 1271 1272 Args: 1273 score: Unnormalized attention scores, shape `[batch_size, alignments_size]` 1274 previous_alignments: Previous attention distribution, shape 1275 `[batch_size, alignments_size]` 1276 sigmoid_noise: Standard deviation of pre-sigmoid noise. Setting this larger 1277 than 0 will encourage the model to produce large attention scores, 1278 effectively making the choosing probabilities discrete and the resulting 1279 attention distribution one-hot. It should be set to 0 at test-time, and 1280 when hard attention is not desired. 
    mode: How to compute the attention distribution.  Must be one of
      'recursive', 'parallel', or 'hard'.  See the docstring for
      `tf.contrib.seq2seq.monotonic_attention` for more information.
    seed: (optional) Random seed for pre-sigmoid noise.

  Returns:
    A `[batch_size, alignments_size]`-shape tensor corresponding to the
    resulting attention distribution.
  """
  # Optionally add pre-sigmoid noise to the scores
  if sigmoid_noise > 0:
    noise = random_ops.random_normal(array_ops.shape(score), dtype=score.dtype,
                                     seed=seed)
    score += sigmoid_noise*noise
  # Compute "choosing" probabilities from the attention scores
  if mode == "hard":
    # When mode is hard, use a hard sigmoid
    p_choose_i = math_ops.cast(score > 0, score.dtype)
  else:
    p_choose_i = math_ops.sigmoid(score)
  # Convert from choosing probabilities to attention distribution
  return monotonic_attention(p_choose_i, previous_alignments, mode)


class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism):
  """Base attention mechanism for monotonic attention.

  Simply overrides the initial_alignments function to provide a dirac
  distribution, which is needed in order for the monotonic attention
  distributions to have the correct behavior.
  """

  def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the monotonic attentions.

    Initializes to dirac distributions, i.e. [1, 0, 0, ...memory length..., 0]
    for all entries in the batch.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    max_time = self._alignments_size
    return array_ops.one_hot(
        array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time,
        dtype=dtype)


class _BaseMonotonicAttentionMechanismV2(_BaseAttentionMechanismV2):
  """Base attention mechanism for monotonic attention.

  Simply overrides the initial_alignments function to provide a dirac
  distribution, which is needed in order for the monotonic attention
  distributions to have the correct behavior.
  """

  def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the monotonic attentions.

    Initializes to dirac distributions, i.e. [1, 0, 0, ...memory length..., 0]
    for all entries in the batch.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    max_time = self._alignments_size
    return array_ops.one_hot(
        array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time,
        dtype=dtype)


class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
  """Monotonic attention mechanism with Bahdanau-style energy function.

  This type of attention enforces a monotonic constraint on the attention
  distributions; that is, once the model attends to a given point in the memory
  it can't attend to any prior points at subsequent output timesteps.  It
  achieves this by using the _monotonic_probability_fn instead of softmax to
  construct its attention distributions.
  Since the attention scores are passed
  through a sigmoid, a learnable scalar bias parameter is applied after the
  score function and before the sigmoid.  Otherwise, it is equivalent to
  BahdanauAttention.  This approach is proposed in

  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
  ICML 2017.  https://arxiv.org/abs/1704.00784
  """

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               normalize=False,
               score_mask_value=None,
               sigmoid_noise=0.,
               sigmoid_noise_seed=None,
               score_bias_init=0.,
               mode="parallel",
               dtype=None,
               name="BahdanauMonotonicAttention"):
    """Construct the Attention mechanism.

    Args:
      num_units: The depth of the query mechanism.
      memory: The memory to query; usually the output of an RNN encoder.  This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory.  If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      normalize: Python boolean.  Whether to normalize the energy term.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`.  The default is -inf.  Only used if
        `memory_sequence_length` is not None.
      sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
        for `_monotonic_probability_fn` for more information.
      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
      score_bias_init: Initial value for score bias scalar.  It's recommended to
        initialize this to a negative value when the length of the memory is
        large.
      mode: How to compute the attention distribution.  Must be one of
        'recursive', 'parallel', or 'hard'.  See the docstring for
        `tf.contrib.seq2seq.monotonic_attention` for more information.
      dtype: The data type for the query and memory layers of the attention
        mechanism.
      name: Name to use when creating ops.
    """
    # Set up the monotonic probability fn with supplied parameters
    if dtype is None:
      dtype = dtypes.float32
    wrapped_probability_fn = functools.partial(
        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
        seed=sigmoid_noise_seed)
    super(BahdanauMonotonicAttention, self).__init__(
        query_layer=layers_core.Dense(
            num_units, name="query_layer", use_bias=False, dtype=dtype),
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
        name=name)
    self._num_units = num_units
    self._normalize = normalize
    self._name = name
    self._score_bias_init = score_bias_init

  def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
1451 """ 1452 with variable_scope.variable_scope( 1453 None, "bahdanau_monotonic_attention", [query]): 1454 processed_query = self.query_layer(query) if self.query_layer else query 1455 attention_v = variable_scope.get_variable( 1456 "attention_v", [self._num_units], dtype=query.dtype) 1457 if not self._normalize: 1458 attention_g = None 1459 attention_b = None 1460 else: 1461 attention_g = variable_scope.get_variable( 1462 "attention_g", dtype=query.dtype, 1463 initializer=init_ops.constant_initializer( 1464 math.sqrt((1. / self._num_units))), 1465 shape=()) 1466 attention_b = variable_scope.get_variable( 1467 "attention_b", [self._num_units], dtype=query.dtype, 1468 initializer=init_ops.zeros_initializer()) 1469 score = _bahdanau_score(processed_query, self._keys, attention_v, 1470 attention_g=attention_g, attention_b=attention_b) 1471 score_bias = variable_scope.get_variable( 1472 "attention_score_bias", dtype=processed_query.dtype, 1473 initializer=self._score_bias_init) 1474 score += score_bias 1475 alignments = self._probability_fn(score, state) 1476 next_state = alignments 1477 return alignments, next_state 1478 1479 1480class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): 1481 """Monotonic attention mechanism with Bahadanau-style energy function. 1482 1483 This type of attention enforces a monotonic constraint on the attention 1484 distributions; that is once the model attends to a given point in the memory 1485 it can't attend to any prior points at subsequence output timesteps. It 1486 achieves this by using the _monotonic_probability_fn instead of softmax to 1487 construct its attention distributions. Since the attention scores are passed 1488 through a sigmoid, a learnable scalar bias parameter is applied after the 1489 score function and before the sigmoid. Otherwise, it is equivalent to 1490 BahdanauAttention. This approach is proposed in 1491 1492 Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, 1493 "Online and Linear-Time Attention by Enforcing Monotonic Alignments." 1494 ICML 2017. https://arxiv.org/abs/1704.00784 1495 """ 1496 1497 def __init__(self, 1498 units, 1499 memory, 1500 memory_sequence_length=None, 1501 normalize=False, 1502 sigmoid_noise=0., 1503 sigmoid_noise_seed=None, 1504 score_bias_init=0., 1505 mode="parallel", 1506 kernel_initializer="glorot_uniform", 1507 dtype=None, 1508 name="BahdanauMonotonicAttention", 1509 **kwargs): 1510 """Construct the Attention mechanism. 1511 1512 Args: 1513 units: The depth of the query mechanism. 1514 memory: The memory to query; usually the output of an RNN encoder. This 1515 tensor should be shaped `[batch_size, max_time, ...]`. 1516 memory_sequence_length: (optional): Sequence lengths for the batch entries 1517 in memory. If provided, the memory tensor rows are masked with zeros 1518 for values past the respective sequence lengths. 1519 normalize: Python boolean. Whether to normalize the energy term. 1520 sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring 1521 for `_monotonic_probability_fn` for more information. 1522 sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. 1523 score_bias_init: Initial value for score bias scalar. It's recommended to 1524 initialize this to a negative value when the length of the memory is 1525 large. 1526 mode: How to compute the attention distribution. Must be one of 1527 'recursive', 'parallel', or 'hard'. See the docstring for 1528 `tf.contrib.seq2seq.monotonic_attention` for more information. 
1529 kernel_initializer: (optional), the name of the initializer for the 1530 attention kernel. 1531 dtype: The data type for the query and memory layers of the attention 1532 mechanism. 1533 name: Name to use when creating ops. 1534 **kwargs: Dictionary that contains other common arguments for layer 1535 creation. 1536 """ 1537 # Set up the monotonic probability fn with supplied parameters 1538 if dtype is None: 1539 dtype = dtypes.float32 1540 wrapped_probability_fn = functools.partial( 1541 _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, 1542 seed=sigmoid_noise_seed) 1543 query_layer = kwargs.pop("query_layer", None) 1544 if not query_layer: 1545 query_layer = layers.Dense( 1546 units, name="query_layer", use_bias=False, dtype=dtype) 1547 memory_layer = kwargs.pop("memory_layer", None) 1548 if not memory_layer: 1549 memory_layer = layers.Dense( 1550 units, name="memory_layer", use_bias=False, dtype=dtype) 1551 self.units = units 1552 self.normalize = normalize 1553 self.sigmoid_noise = sigmoid_noise 1554 self.sigmoid_noise_seed = sigmoid_noise_seed 1555 self.score_bias_init = score_bias_init 1556 self.mode = mode 1557 self.kernel_initializer = initializers.get(kernel_initializer) 1558 self.attention_v = None 1559 self.attention_score_bias = None 1560 self.attention_g = None 1561 self.attention_b = None 1562 super(BahdanauMonotonicAttentionV2, self).__init__( 1563 memory=memory, 1564 memory_sequence_length=memory_sequence_length, 1565 query_layer=query_layer, 1566 memory_layer=memory_layer, 1567 probability_fn=wrapped_probability_fn, 1568 name=name, 1569 dtype=dtype, 1570 **kwargs) 1571 1572 def build(self, input_shape): 1573 super(BahdanauMonotonicAttentionV2, self).build(input_shape) 1574 if self.attention_v is None: 1575 self.attention_v = self.add_weight( 1576 "attention_v", [self.units], dtype=self.dtype, 1577 initializer=self.kernel_initializer) 1578 if self.attention_score_bias is None: 1579 self.attention_score_bias = self.add_weight( 1580 "attention_score_bias", shape=(), dtype=self.dtype, 1581 initializer=init_ops.constant_initializer( 1582 self.score_bias_init, dtype=self.dtype)) 1583 if self.normalize and self.attention_g is None and self.attention_b is None: 1584 self.attention_g = self.add_weight( 1585 "attention_g", dtype=self.dtype, 1586 initializer=init_ops.constant_initializer( 1587 math.sqrt((1. / self.units))), 1588 shape=()) 1589 self.attention_b = self.add_weight( 1590 "attention_b", [self.units], dtype=self.dtype, 1591 initializer=init_ops.zeros_initializer()) 1592 self.built = True 1593 1594 def _calculate_attention(self, query, state): 1595 """Score the query based on the keys and values. 1596 1597 Args: 1598 query: Tensor of dtype matching `self.values` and shape 1599 `[batch_size, query_depth]`. 1600 state: Tensor of dtype matching `self.values` and shape 1601 `[batch_size, alignments_size]` 1602 (`alignments_size` is memory's `max_time`). 1603 1604 Returns: 1605 alignments: Tensor of dtype matching `self.values` and shape 1606 `[batch_size, alignments_size]` (`alignments_size` is memory's 1607 `max_time`). 
1608 """ 1609 processed_query = self.query_layer(query) if self.query_layer else query 1610 score = _bahdanau_score(processed_query, self.keys, self.attention_v, 1611 attention_g=self.attention_g, 1612 attention_b=self.attention_b) 1613 score += self.attention_score_bias 1614 alignments = self.probability_fn(score, state) 1615 next_state = alignments 1616 return alignments, next_state 1617 1618 def get_config(self): 1619 config = { 1620 "units": self.units, 1621 "normalize": self.normalize, 1622 "sigmoid_noise": self.sigmoid_noise, 1623 "sigmoid_noise_seed": self.sigmoid_noise_seed, 1624 "score_bias_init": self.score_bias_init, 1625 "mode": self.mode, 1626 "kernel_initializer": initializers.serialize(self.kernel_initializer), 1627 } 1628 base_config = super(BahdanauMonotonicAttentionV2, self).get_config() 1629 return dict(list(base_config.items()) + list(config.items())) 1630 1631 @classmethod 1632 def from_config(cls, config, custom_objects=None): 1633 config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( 1634 config, custom_objects=custom_objects) 1635 return cls(**config) 1636 1637 1638class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): 1639 """Monotonic attention mechanism with Luong-style energy function. 1640 1641 This type of attention enforces a monotonic constraint on the attention 1642 distributions; that is once the model attends to a given point in the memory 1643 it can't attend to any prior points at subsequence output timesteps. It 1644 achieves this by using the _monotonic_probability_fn instead of softmax to 1645 construct its attention distributions. Otherwise, it is equivalent to 1646 LuongAttention. This approach is proposed in 1647 1648 Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, 1649 "Online and Linear-Time Attention by Enforcing Monotonic Alignments." 1650 ICML 2017. https://arxiv.org/abs/1704.00784 1651 """ 1652 1653 def __init__(self, 1654 num_units, 1655 memory, 1656 memory_sequence_length=None, 1657 scale=False, 1658 score_mask_value=None, 1659 sigmoid_noise=0., 1660 sigmoid_noise_seed=None, 1661 score_bias_init=0., 1662 mode="parallel", 1663 dtype=None, 1664 name="LuongMonotonicAttention"): 1665 """Construct the Attention mechanism. 1666 1667 Args: 1668 num_units: The depth of the query mechanism. 1669 memory: The memory to query; usually the output of an RNN encoder. This 1670 tensor should be shaped `[batch_size, max_time, ...]`. 1671 memory_sequence_length (optional): Sequence lengths for the batch entries 1672 in memory. If provided, the memory tensor rows are masked with zeros 1673 for values past the respective sequence lengths. 1674 scale: Python boolean. Whether to scale the energy term. 1675 score_mask_value: (optional): The mask value for score before passing into 1676 `probability_fn`. The default is -inf. Only used if 1677 `memory_sequence_length` is not None. 1678 sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring 1679 for `_monotonic_probability_fn` for more information. 1680 sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. 1681 score_bias_init: Initial value for score bias scalar. It's recommended to 1682 initialize this to a negative value when the length of the memory is 1683 large. 1684 mode: How to compute the attention distribution. Must be one of 1685 'recursive', 'parallel', or 'hard'. See the docstring for 1686 `tf.contrib.seq2seq.monotonic_attention` for more information. 
1687 dtype: The data type for the query and memory layers of the attention 1688 mechanism. 1689 name: Name to use when creating ops. 1690 """ 1691 # Set up the monotonic probability fn with supplied parameters 1692 if dtype is None: 1693 dtype = dtypes.float32 1694 wrapped_probability_fn = functools.partial( 1695 _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, 1696 seed=sigmoid_noise_seed) 1697 super(LuongMonotonicAttention, self).__init__( 1698 query_layer=None, 1699 memory_layer=layers_core.Dense( 1700 num_units, name="memory_layer", use_bias=False, dtype=dtype), 1701 memory=memory, 1702 probability_fn=wrapped_probability_fn, 1703 memory_sequence_length=memory_sequence_length, 1704 score_mask_value=score_mask_value, 1705 name=name) 1706 self._num_units = num_units 1707 self._scale = scale 1708 self._score_bias_init = score_bias_init 1709 self._name = name 1710 1711 def __call__(self, query, state): 1712 """Score the query based on the keys and values. 1713 1714 Args: 1715 query: Tensor of dtype matching `self.values` and shape 1716 `[batch_size, query_depth]`. 1717 state: Tensor of dtype matching `self.values` and shape 1718 `[batch_size, alignments_size]` 1719 (`alignments_size` is memory's `max_time`). 1720 1721 Returns: 1722 alignments: Tensor of dtype matching `self.values` and shape 1723 `[batch_size, alignments_size]` (`alignments_size` is memory's 1724 `max_time`). 1725 """ 1726 with variable_scope.variable_scope(None, "luong_monotonic_attention", 1727 [query]): 1728 attention_g = None 1729 if self._scale: 1730 attention_g = variable_scope.get_variable( 1731 "attention_g", dtype=query.dtype, 1732 initializer=init_ops.ones_initializer, shape=()) 1733 score = _luong_score(query, self._keys, attention_g) 1734 score_bias = variable_scope.get_variable( 1735 "attention_score_bias", dtype=query.dtype, 1736 initializer=self._score_bias_init) 1737 score += score_bias 1738 alignments = self._probability_fn(score, state) 1739 next_state = alignments 1740 return alignments, next_state 1741 1742 1743class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2): 1744 """Monotonic attention mechanism with Luong-style energy function. 1745 1746 This type of attention enforces a monotonic constraint on the attention 1747 distributions; that is once the model attends to a given point in the memory 1748 it can't attend to any prior points at subsequence output timesteps. It 1749 achieves this by using the _monotonic_probability_fn instead of softmax to 1750 construct its attention distributions. Otherwise, it is equivalent to 1751 LuongAttention. This approach is proposed in 1752 1753 [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck, 1754 "Online and Linear-Time Attention by Enforcing Monotonic Alignments." 1755 ICML 2017.](https://arxiv.org/abs/1704.00784) 1756 """ 1757 1758 def __init__(self, 1759 units, 1760 memory, 1761 memory_sequence_length=None, 1762 scale=False, 1763 sigmoid_noise=0., 1764 sigmoid_noise_seed=None, 1765 score_bias_init=0., 1766 mode="parallel", 1767 dtype=None, 1768 name="LuongMonotonicAttention", 1769 **kwargs): 1770 """Construct the Attention mechanism. 1771 1772 Args: 1773 units: The depth of the query mechanism. 1774 memory: The memory to query; usually the output of an RNN encoder. This 1775 tensor should be shaped `[batch_size, max_time, ...]`. 1776 memory_sequence_length: (optional): Sequence lengths for the batch entries 1777 in memory. 
If provided, the memory tensor rows are masked with zeros 1778 for values past the respective sequence lengths. 1779 scale: Python boolean. Whether to scale the energy term. 1780 sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring 1781 for `_monotonic_probability_fn` for more information. 1782 sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise. 1783 score_bias_init: Initial value for score bias scalar. It's recommended to 1784 initialize this to a negative value when the length of the memory is 1785 large. 1786 mode: How to compute the attention distribution. Must be one of 1787 'recursive', 'parallel', or 'hard'. See the docstring for 1788 `tf.contrib.seq2seq.monotonic_attention` for more information. 1789 dtype: The data type for the query and memory layers of the attention 1790 mechanism. 1791 name: Name to use when creating ops. 1792 **kwargs: Dictionary that contains other common arguments for layer 1793 creation. 1794 """ 1795 # Set up the monotonic probability fn with supplied parameters 1796 if dtype is None: 1797 dtype = dtypes.float32 1798 wrapped_probability_fn = functools.partial( 1799 _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, 1800 seed=sigmoid_noise_seed) 1801 memory_layer = kwargs.pop("memory_layer", None) 1802 if not memory_layer: 1803 memory_layer = layers.Dense( 1804 units, name="memory_layer", use_bias=False, dtype=dtype) 1805 self.units = units 1806 self.scale = scale 1807 self.sigmoid_noise = sigmoid_noise 1808 self.sigmoid_noise_seed = sigmoid_noise_seed 1809 self.score_bias_init = score_bias_init 1810 self.mode = mode 1811 self.attention_g = None 1812 self.attention_score_bias = None 1813 super(LuongMonotonicAttentionV2, self).__init__( 1814 memory=memory, 1815 memory_sequence_length=memory_sequence_length, 1816 query_layer=None, 1817 memory_layer=memory_layer, 1818 probability_fn=wrapped_probability_fn, 1819 name=name, 1820 dtype=dtype, 1821 **kwargs) 1822 1823 def build(self, input_shape): 1824 super(LuongMonotonicAttentionV2, self).build(input_shape) 1825 if self.scale and self.attention_g is None: 1826 self.attention_g = self.add_weight( 1827 "attention_g", initializer=init_ops.ones_initializer, shape=()) 1828 if self.attention_score_bias is None: 1829 self.attention_score_bias = self.add_weight( 1830 "attention_score_bias", shape=(), 1831 initializer=init_ops.constant_initializer( 1832 self.score_bias_init, dtype=self.dtype)) 1833 self.built = True 1834 1835 def _calculate_attention(self, query, state): 1836 """Score the query based on the keys and values. 1837 1838 Args: 1839 query: Tensor of dtype matching `self.values` and shape 1840 `[batch_size, query_depth]`. 1841 state: Tensor of dtype matching `self.values` and shape 1842 `[batch_size, alignments_size]` 1843 (`alignments_size` is memory's `max_time`). 1844 1845 Returns: 1846 alignments: Tensor of dtype matching `self.values` and shape 1847 `[batch_size, alignments_size]` (`alignments_size` is memory's 1848 `max_time`). 
1849 next_state: Same as alignments 1850 """ 1851 score = _luong_score(query, self.keys, self.attention_g) 1852 score += self.attention_score_bias 1853 alignments = self.probability_fn(score, state) 1854 next_state = alignments 1855 return alignments, next_state 1856 1857 def get_config(self): 1858 config = { 1859 "units": self.units, 1860 "scale": self.scale, 1861 "sigmoid_noise": self.sigmoid_noise, 1862 "sigmoid_noise_seed": self.sigmoid_noise_seed, 1863 "score_bias_init": self.score_bias_init, 1864 "mode": self.mode, 1865 } 1866 base_config = super(LuongMonotonicAttentionV2, self).get_config() 1867 return dict(list(base_config.items()) + list(config.items())) 1868 1869 @classmethod 1870 def from_config(cls, config, custom_objects=None): 1871 config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config( 1872 config, custom_objects=custom_objects) 1873 return cls(**config) 1874 1875 1876class AttentionWrapperState( 1877 collections.namedtuple("AttentionWrapperState", 1878 ("cell_state", "attention", "time", "alignments", 1879 "alignment_history", "attention_state"))): 1880 """`namedtuple` storing the state of a `AttentionWrapper`. 1881 1882 Contains: 1883 1884 - `cell_state`: The state of the wrapped `RNNCell` at the previous time 1885 step. 1886 - `attention`: The attention emitted at the previous time step. 1887 - `time`: int32 scalar containing the current time step. 1888 - `alignments`: A single or tuple of `Tensor`(s) containing the alignments 1889 emitted at the previous time step for each attention mechanism. 1890 - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s) 1891 containing alignment matrices from all time steps for each attention 1892 mechanism. Call `stack()` on each to convert to a `Tensor`. 1893 - `attention_state`: A single or tuple of nested objects 1894 containing attention mechanism state for each attention mechanism. 1895 The objects may contain Tensors or TensorArrays. 1896 """ 1897 1898 def clone(self, **kwargs): 1899 """Clone this object, overriding components provided by kwargs. 1900 1901 The new state fields' shape must match original state fields' shape. This 1902 will be validated, and original fields' shape will be propagated to new 1903 fields. 1904 1905 Example: 1906 1907 ```python 1908 initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...) 1909 initial_state = initial_state.clone(cell_state=encoder_state) 1910 ``` 1911 1912 Args: 1913 **kwargs: Any properties of the state object to replace in the returned 1914 `AttentionWrapperState`. 1915 1916 Returns: 1917 A new `AttentionWrapperState` whose properties are the same as 1918 this one, except any overridden properties as provided in `kwargs`. 1919 """ 1920 def with_same_shape(old, new): 1921 """Check and set new tensor's shape.""" 1922 if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor): 1923 if not context.executing_eagerly(): 1924 return tensor_util.with_same_shape(old, new) 1925 else: 1926 if old.shape.as_list() != new.shape.as_list(): 1927 raise ValueError("The shape of the AttentionWrapperState is " 1928 "expected to be same as the one to clone. 
" 1929 "self.shape: %s, input.shape: %s" % 1930 (old.shape, new.shape)) 1931 return new 1932 return new 1933 1934 return nest.map_structure( 1935 with_same_shape, 1936 self, 1937 super(AttentionWrapperState, self)._replace(**kwargs)) 1938 1939 1940def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None, 1941 check_inner_dims_defined=True): 1942 """Convert to tensor and possibly mask `memory`. 1943 1944 Args: 1945 memory: `Tensor`, shaped `[batch_size, max_time, ...]`. 1946 memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. 1947 memory_mask: `boolean` tensor with shape [batch_size, max_time]. The memory 1948 should be skipped when the corresponding mask is False. 1949 check_inner_dims_defined: Python boolean. If `True`, the `memory` 1950 argument's shape is checked to ensure all but the two outermost 1951 dimensions are fully defined. 1952 1953 Returns: 1954 A (possibly masked), checked, new `memory`. 1955 1956 Raises: 1957 ValueError: If `check_inner_dims_defined` is `True` and not 1958 `memory.shape[2:].is_fully_defined()`. 1959 """ 1960 memory = nest.map_structure( 1961 lambda m: ops.convert_to_tensor(m, name="memory"), memory) 1962 if memory_sequence_length is not None and memory_mask is not None: 1963 raise ValueError("memory_sequence_length and memory_mask can't be provided " 1964 "at same time.") 1965 if memory_sequence_length is not None: 1966 memory_sequence_length = ops.convert_to_tensor( 1967 memory_sequence_length, name="memory_sequence_length") 1968 if check_inner_dims_defined: 1969 def _check_dims(m): 1970 if not m.get_shape()[2:].is_fully_defined(): 1971 raise ValueError("Expected memory %s to have fully defined inner dims, " 1972 "but saw shape: %s" % (m.name, m.get_shape())) 1973 nest.map_structure(_check_dims, memory) 1974 if memory_sequence_length is None and memory_mask is None: 1975 return memory 1976 elif memory_sequence_length is not None: 1977 seq_len_mask = array_ops.sequence_mask( 1978 memory_sequence_length, 1979 maxlen=array_ops.shape(nest.flatten(memory)[0])[1], 1980 dtype=nest.flatten(memory)[0].dtype) 1981 else: 1982 # For memory_mask is not None 1983 seq_len_mask = math_ops.cast( 1984 memory_mask, dtype=nest.flatten(memory)[0].dtype) 1985 def _maybe_mask(m, seq_len_mask): 1986 """Mask the memory based on the memory mask.""" 1987 rank = m.get_shape().ndims 1988 rank = rank if rank is not None else array_ops.rank(m) 1989 extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) 1990 seq_len_mask = array_ops.reshape( 1991 seq_len_mask, 1992 array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) 1993 return m * seq_len_mask 1994 1995 return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory) 1996 1997 1998def _maybe_mask_score(score, memory_sequence_length=None, memory_mask=None, 1999 score_mask_value=None): 2000 """Mask the attention score based on the masks.""" 2001 if memory_sequence_length is None and memory_mask is None: 2002 return score 2003 if memory_sequence_length is not None and memory_mask is not None: 2004 raise ValueError("memory_sequence_length and memory_mask can't be provided " 2005 "at same time.") 2006 if memory_sequence_length is not None: 2007 message = "All values in memory_sequence_length must greater than zero." 
2008 with ops.control_dependencies( 2009 [check_ops.assert_positive(memory_sequence_length, message=message)]): 2010 memory_mask = array_ops.sequence_mask( 2011 memory_sequence_length, maxlen=array_ops.shape(score)[1]) 2012 score_mask_values = score_mask_value * array_ops.ones_like(score) 2013 return array_ops.where(memory_mask, score, score_mask_values) 2014 2015 2016def hardmax(logits, name=None): 2017 """Returns batched one-hot vectors. 2018 2019 The depth index containing the `1` is that of the maximum logit value. 2020 2021 Args: 2022 logits: A batch tensor of logit values. 2023 name: Name to use when creating ops. 2024 Returns: 2025 A batched one-hot tensor. 2026 """ 2027 with ops.name_scope(name, "Hardmax", [logits]): 2028 logits = ops.convert_to_tensor(logits, name="logits") 2029 if tensor_shape.dimension_value(logits.get_shape()[-1]) is not None: 2030 depth = tensor_shape.dimension_value(logits.get_shape()[-1]) 2031 else: 2032 depth = array_ops.shape(logits)[-1] 2033 return array_ops.one_hot( 2034 math_ops.argmax(logits, -1), depth, dtype=logits.dtype) 2035 2036 2037def _compute_attention(attention_mechanism, cell_output, attention_state, 2038 attention_layer): 2039 """Computes the attention and alignments for a given attention_mechanism.""" 2040 if isinstance(attention_mechanism, _BaseAttentionMechanismV2): 2041 alignments, next_attention_state = attention_mechanism( 2042 [cell_output, attention_state]) 2043 else: 2044 # For other class, assume they are following _BaseAttentionMechanism, which 2045 # takes query and state as separate parameter. 2046 alignments, next_attention_state = attention_mechanism( 2047 cell_output, state=attention_state) 2048 2049 # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] 2050 expanded_alignments = array_ops.expand_dims(alignments, 1) 2051 # Context is the inner product of alignments and values along the 2052 # memory time dimension. 2053 # alignments shape is 2054 # [batch_size, 1, memory_time] 2055 # attention_mechanism.values shape is 2056 # [batch_size, memory_time, memory_size] 2057 # the batched matmul is over memory_time, so the output shape is 2058 # [batch_size, 1, memory_size]. 2059 # we then squeeze out the singleton dim. 2060 context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values) 2061 context_ = array_ops.squeeze(context_, [1]) 2062 2063 if attention_layer is not None: 2064 attention = attention_layer(array_ops.concat([cell_output, context_], 1)) 2065 else: 2066 attention = context_ 2067 2068 return attention, alignments, next_attention_state 2069 2070 2071class AttentionWrapper(rnn_cell_impl.RNNCell): 2072 """Wraps another `RNNCell` with attention. 2073 """ 2074 2075 def __init__(self, 2076 cell, 2077 attention_mechanism, 2078 attention_layer_size=None, 2079 alignment_history=False, 2080 cell_input_fn=None, 2081 output_attention=True, 2082 initial_cell_state=None, 2083 name=None, 2084 attention_layer=None, 2085 attention_fn=None): 2086 """Construct the `AttentionWrapper`. 2087 2088 **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in 2089 `AttentionWrapper`, then you must ensure that: 2090 2091 - The encoder output has been tiled to `beam_width` via 2092 `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`). 2093 - The `batch_size` argument passed to the `zero_state` method of this 2094 wrapper is equal to `true_batch_size * beam_width`. 
2095 - The initial state created with `zero_state` above contains a 2096 `cell_state` value containing properly tiled final state from the 2097 encoder. 2098 2099 An example: 2100 2101 ``` 2102 tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( 2103 encoder_outputs, multiplier=beam_width) 2104 tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( 2105 encoder_final_state, multiplier=beam_width) 2106 tiled_sequence_length = tf.contrib.seq2seq.tile_batch( 2107 sequence_length, multiplier=beam_width) 2108 attention_mechanism = MyFavoriteAttentionMechanism( 2109 num_units=attention_depth, 2110 memory=tiled_encoder_outputs, 2111 memory_sequence_length=tiled_sequence_length) 2112 attention_cell = AttentionWrapper(cell, attention_mechanism, ...) 2113 decoder_initial_state = attention_cell.zero_state( 2114 dtype=dtype, batch_size=true_batch_size * beam_width) 2115 decoder_initial_state = decoder_initial_state.clone( 2116 cell_state=tiled_encoder_final_state) 2117 ``` 2118 2119 Args: 2120 cell: An instance of `RNNCell`. 2121 attention_mechanism: A list of `AttentionMechanism` instances or a single 2122 instance. 2123 attention_layer_size: A list of Python integers or a single Python 2124 integer, the depth of the attention (output) layer(s). If None 2125 (default), use the context as attention at each time step. Otherwise, 2126 feed the context and cell output into the attention layer to generate 2127 attention at each time step. If attention_mechanism is a list, 2128 attention_layer_size must be a list of the same length. If 2129 attention_layer is set, this must be None. If attention_fn is set, 2130 it must be guaranteed that the outputs of attention_fn also meet the 2131 above requirements. 2132 alignment_history: Python boolean, whether to store alignment history 2133 from all time steps in the final output state (currently stored as a 2134 time major `TensorArray` on which you must call `stack()`). 2135 cell_input_fn: (optional) A `callable`. The default is: 2136 `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`. 2137 output_attention: Python bool. If `True` (default), the output at each 2138 time step is the attention value. This is the behavior of Luong-style 2139 attention mechanisms. If `False`, the output at each time step is 2140 the output of `cell`. This is the behavior of Bahdanau-style 2141 attention mechanisms. In both cases, the `attention` tensor is 2142 propagated to the next time step via the state and is used there. 2143 This flag only controls whether the attention mechanism is propagated 2144 up to the next cell in an RNN stack or to the top RNN output. 2145 initial_cell_state: The initial state value to use for the cell when 2146 the user calls `zero_state()`. Note that if this value is provided 2147 now, and the user uses a `batch_size` argument of `zero_state` which 2148 does not match the batch size of `initial_cell_state`, proper 2149 behavior is not guaranteed. 2150 name: Name to use when creating ops. 2151 attention_layer: A list of `tf.layers.Layer` instances or a 2152 single `tf.layers.Layer` instance taking the context and cell output as 2153 inputs to generate attention at each time step. If None (default), use 2154 the context as attention at each time step. If attention_mechanism is a 2155 list, attention_layer must be a list of the same length. If 2156 attention_layer_size is set, this must be None.
2157 attention_fn: An optional callable function that allows users to provide 2158 their own customized attention function, which takes input 2159 (attention_mechanism, cell_output, attention_state, attention_layer) and 2160 outputs (attention, alignments, next_attention_state). If provided, 2161 the attention_layer_size should be the size of the outputs of 2162 attention_fn. 2163 2164 Raises: 2165 TypeError: `attention_layer_size` is not None and (`attention_mechanism` 2166 is a list but `attention_layer_size` is not; or vice versa). 2167 ValueError: if `attention_layer_size` is not None, `attention_mechanism` 2168 is a list, and its length does not match that of `attention_layer_size`; 2169 if `attention_layer_size` and `attention_layer` are set simultaneously. 2170 """ 2171 super(AttentionWrapper, self).__init__(name=name) 2172 rnn_cell_impl.assert_like_rnncell("cell", cell) 2173 if isinstance(attention_mechanism, (list, tuple)): 2174 self._is_multi = True 2175 attention_mechanisms = attention_mechanism 2176 for attention_mechanism in attention_mechanisms: 2177 if not isinstance(attention_mechanism, AttentionMechanism): 2178 raise TypeError( 2179 "attention_mechanism must contain only instances of " 2180 "AttentionMechanism, saw type: %s" 2181 % type(attention_mechanism).__name__) 2182 else: 2183 self._is_multi = False 2184 if not isinstance(attention_mechanism, AttentionMechanism): 2185 raise TypeError( 2186 "attention_mechanism must be an AttentionMechanism or list of " 2187 "multiple AttentionMechanism instances, saw type: %s" 2188 % type(attention_mechanism).__name__) 2189 attention_mechanisms = (attention_mechanism,) 2190 2191 if cell_input_fn is None: 2192 cell_input_fn = ( 2193 lambda inputs, attention: array_ops.concat([inputs, attention], -1)) 2194 else: 2195 if not callable(cell_input_fn): 2196 raise TypeError( 2197 "cell_input_fn must be callable, saw type: %s" 2198 % type(cell_input_fn).__name__) 2199 2200 if attention_layer_size is not None and attention_layer is not None: 2201 raise ValueError("Only one of attention_layer_size and attention_layer " 2202 "should be set") 2203 2204 if attention_layer_size is not None: 2205 attention_layer_sizes = tuple( 2206 attention_layer_size 2207 if isinstance(attention_layer_size, (list, tuple)) 2208 else (attention_layer_size,)) 2209 if len(attention_layer_sizes) != len(attention_mechanisms): 2210 raise ValueError( 2211 "If provided, attention_layer_size must contain exactly one " 2212 "integer per attention_mechanism, saw: %d vs %d" 2213 % (len(attention_layer_sizes), len(attention_mechanisms))) 2214 self._attention_layers = tuple( 2215 layers_core.Dense( 2216 attention_layer_size, 2217 name="attention_layer", 2218 use_bias=False, 2219 dtype=attention_mechanisms[i].dtype) 2220 for i, attention_layer_size in enumerate(attention_layer_sizes)) 2221 self._attention_layer_size = sum(attention_layer_sizes) 2222 elif attention_layer is not None: 2223 self._attention_layers = tuple( 2224 attention_layer 2225 if isinstance(attention_layer, (list, tuple)) 2226 else (attention_layer,)) 2227 if len(self._attention_layers) != len(attention_mechanisms): 2228 raise ValueError( 2229 "If provided, attention_layer must contain exactly one " 2230 "layer per attention_mechanism, saw: %d vs %d" 2231 % (len(self._attention_layers), len(attention_mechanisms))) 2232 self._attention_layer_size = sum( 2233 tensor_shape.dimension_value(layer.compute_output_shape( 2234 [None, 2235 cell.output_size + tensor_shape.dimension_value( 2236 
mechanism.values.shape[-1])])[-1]) 2237 for layer, mechanism in zip( 2238 self._attention_layers, attention_mechanisms)) 2239 else: 2240 self._attention_layers = None 2241 self._attention_layer_size = sum( 2242 tensor_shape.dimension_value(attention_mechanism.values.shape[-1]) 2243 for attention_mechanism in attention_mechanisms) 2244 2245 if attention_fn is None: 2246 attention_fn = _compute_attention 2247 self._attention_fn = attention_fn 2248 2249 self._cell = cell 2250 self._attention_mechanisms = attention_mechanisms 2251 self._cell_input_fn = cell_input_fn 2252 self._output_attention = output_attention 2253 self._alignment_history = alignment_history 2254 with ops.name_scope(name, "AttentionWrapperInit"): 2255 if initial_cell_state is None: 2256 self._initial_cell_state = None 2257 else: 2258 final_state_tensor = nest.flatten(initial_cell_state)[-1] 2259 state_batch_size = ( 2260 tensor_shape.dimension_value(final_state_tensor.shape[0]) 2261 or array_ops.shape(final_state_tensor)[0]) 2262 error_message = ( 2263 "When constructing AttentionWrapper %s: " % self._base_name + 2264 "Non-matching batch sizes between the memory " 2265 "(encoder output) and initial_cell_state. Are you using " 2266 "the BeamSearchDecoder? You may need to tile your initial state " 2267 "via the tf.contrib.seq2seq.tile_batch function with argument " 2268 "multiple=beam_width.") 2269 with ops.control_dependencies( 2270 self._batch_size_checks(state_batch_size, error_message)): 2271 self._initial_cell_state = nest.map_structure( 2272 lambda s: array_ops.identity(s, name="check_initial_cell_state"), 2273 initial_cell_state) 2274 2275 def _batch_size_checks(self, batch_size, error_message): 2276 return [check_ops.assert_equal(batch_size, 2277 attention_mechanism.batch_size, 2278 message=error_message) 2279 for attention_mechanism in self._attention_mechanisms] 2280 2281 def _item_or_tuple(self, seq): 2282 """Returns `seq` as tuple or the singular element. 2283 2284 Which is returned is determined by how the AttentionMechanism(s) were passed 2285 to the constructor. 2286 2287 Args: 2288 seq: A non-empty sequence of items or generator. 2289 2290 Returns: 2291 Either the values in the sequence as a tuple if AttentionMechanism(s) 2292 were passed to the constructor as a sequence or the singular element. 2293 """ 2294 t = tuple(seq) 2295 if self._is_multi: 2296 return t 2297 else: 2298 return t[0] 2299 2300 @property 2301 def output_size(self): 2302 if self._output_attention: 2303 return self._attention_layer_size 2304 else: 2305 return self._cell.output_size 2306 2307 @property 2308 def state_size(self): 2309 """The `state_size` property of `AttentionWrapper`. 2310 2311 Returns: 2312 An `AttentionWrapperState` tuple containing shapes used by this object. 2313 """ 2314 return AttentionWrapperState( 2315 cell_state=self._cell.state_size, 2316 time=tensor_shape.TensorShape([]), 2317 attention=self._attention_layer_size, 2318 alignments=self._item_or_tuple( 2319 a.alignments_size for a in self._attention_mechanisms), 2320 attention_state=self._item_or_tuple( 2321 a.state_size for a in self._attention_mechanisms), 2322 alignment_history=self._item_or_tuple( 2323 a.alignments_size if self._alignment_history else () 2324 for a in self._attention_mechanisms)) # sometimes a TensorArray 2325 2326 def zero_state(self, batch_size, dtype): 2327 """Return an initial (zero) state tuple for this `AttentionWrapper`. 
2328 2329 **NOTE** Please see the initializer documentation for details of how 2330 to call `zero_state` if using an `AttentionWrapper` with a 2331 `BeamSearchDecoder`. 2332 2333 Args: 2334 batch_size: `0D` integer tensor: the batch size. 2335 dtype: The internal state data type. 2336 2337 Returns: 2338 An `AttentionWrapperState` tuple containing zeroed out tensors and, 2339 possibly, empty `TensorArray` objects. 2340 2341 Raises: 2342 ValueError: (or, possibly at runtime, InvalidArgument), if 2343 `batch_size` does not match the output size of the encoder passed 2344 to the wrapper object at initialization time. 2345 """ 2346 with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 2347 if self._initial_cell_state is not None: 2348 cell_state = self._initial_cell_state 2349 else: 2350 cell_state = self._cell.get_initial_state(batch_size=batch_size, 2351 dtype=dtype) 2352 error_message = ( 2353 "When calling zero_state of AttentionWrapper %s: " % self._base_name + 2354 "Non-matching batch sizes between the memory " 2355 "(encoder output) and the requested batch size. Are you using " 2356 "the BeamSearchDecoder? If so, make sure your encoder output has " 2357 "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and " 2358 "the batch_size= argument passed to zero_state is " 2359 "batch_size * beam_width.") 2360 with ops.control_dependencies( 2361 self._batch_size_checks(batch_size, error_message)): 2362 cell_state = nest.map_structure( 2363 lambda s: array_ops.identity(s, name="checked_cell_state"), 2364 cell_state) 2365 initial_alignments = [ 2366 attention_mechanism.initial_alignments(batch_size, dtype) 2367 for attention_mechanism in self._attention_mechanisms] 2368 return AttentionWrapperState( 2369 cell_state=cell_state, 2370 time=array_ops.zeros([], dtype=dtypes.int32), 2371 attention=_zero_state_tensors(self._attention_layer_size, batch_size, 2372 dtype), 2373 alignments=self._item_or_tuple(initial_alignments), 2374 attention_state=self._item_or_tuple( 2375 attention_mechanism.initial_state(batch_size, dtype) 2376 for attention_mechanism in self._attention_mechanisms), 2377 alignment_history=self._item_or_tuple( 2378 tensor_array_ops.TensorArray( 2379 dtype, 2380 size=0, 2381 dynamic_size=True, 2382 element_shape=alignment.shape) 2383 if self._alignment_history else () 2384 for alignment in initial_alignments)) 2385 2386 def call(self, inputs, state): 2387 """Perform a step of attention-wrapped RNN. 2388 2389 - Step 1: Mix the `inputs` and previous step's `attention` output via 2390 `cell_input_fn`. 2391 - Step 2: Call the wrapped `cell` with this input and its previous state. 2392 - Step 3: Score the cell's output with `attention_mechanism`. 2393 - Step 4: Calculate the alignments by passing the score through the 2394 `normalizer`. 2395 - Step 5: Calculate the context vector as the inner product between the 2396 alignments and the attention_mechanism's values (memory). 2397 - Step 6: Calculate the attention output by concatenating the cell output 2398 and context through the attention layer (a linear layer with 2399 `attention_layer_size` outputs). 2400 2401 Args: 2402 inputs: (Possibly nested tuple of) Tensor, the input at this time step. 2403 state: An instance of `AttentionWrapperState` containing 2404 tensors from the previous time step. 2405 2406 Returns: 2407 A tuple `(attention_or_cell_output, next_state)`, where: 2408 2409 - `attention_or_cell_output` depending on `output_attention`. 
2410 - `next_state` is an instance of `AttentionWrapperState` 2411 containing the state calculated at this time step. 2412 2413 Raises: 2414 TypeError: If `state` is not an instance of `AttentionWrapperState`. 2415 """ 2416 if not isinstance(state, AttentionWrapperState): 2417 raise TypeError("Expected state to be instance of AttentionWrapperState. " 2418 "Received type %s instead." % type(state)) 2419 2420 # Step 1: Calculate the true inputs to the cell based on the 2421 # previous attention value. 2422 cell_inputs = self._cell_input_fn(inputs, state.attention) 2423 cell_state = state.cell_state 2424 cell_output, next_cell_state = self._cell(cell_inputs, cell_state) 2425 2426 cell_batch_size = ( 2427 tensor_shape.dimension_value(cell_output.shape[0]) or 2428 array_ops.shape(cell_output)[0]) 2429 error_message = ( 2430 "When applying AttentionWrapper %s: " % self.name + 2431 "Non-matching batch sizes between the memory " 2432 "(encoder output) and the query (decoder output). Are you using " 2433 "the BeamSearchDecoder? You may need to tile your memory input via " 2434 "the tf.contrib.seq2seq.tile_batch function with argument " 2435 "multiple=beam_width.") 2436 with ops.control_dependencies( 2437 self._batch_size_checks(cell_batch_size, error_message)): 2438 cell_output = array_ops.identity( 2439 cell_output, name="checked_cell_output") 2440 2441 if self._is_multi: 2442 previous_attention_state = state.attention_state 2443 previous_alignment_history = state.alignment_history 2444 else: 2445 previous_attention_state = [state.attention_state] 2446 previous_alignment_history = [state.alignment_history] 2447 2448 all_alignments = [] 2449 all_attentions = [] 2450 all_attention_states = [] 2451 maybe_all_histories = [] 2452 for i, attention_mechanism in enumerate(self._attention_mechanisms): 2453 attention, alignments, next_attention_state = self._attention_fn( 2454 attention_mechanism, cell_output, previous_attention_state[i], 2455 self._attention_layers[i] if self._attention_layers else None) 2456 alignment_history = previous_alignment_history[i].write( 2457 state.time, alignments) if self._alignment_history else () 2458 2459 all_attention_states.append(next_attention_state) 2460 all_alignments.append(alignments) 2461 all_attentions.append(attention) 2462 maybe_all_histories.append(alignment_history) 2463 2464 attention = array_ops.concat(all_attentions, 1) 2465 next_state = AttentionWrapperState( 2466 time=state.time + 1, 2467 cell_state=next_cell_state, 2468 attention=attention, 2469 attention_state=self._item_or_tuple(all_attention_states), 2470 alignments=self._item_or_tuple(all_alignments), 2471 alignment_history=self._item_or_tuple(maybe_all_histories)) 2472 2473 if self._output_attention: 2474 return attention, next_state 2475 else: 2476 return cell_output, next_state 2477
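

# Illustrative sketch (not part of this module's public API): how pre-sigmoid
# scores become a monotonic attention distribution, mirroring
# `_monotonic_probability_fn` above. The helper name
# `_example_monotonic_distribution` and the dummy shapes are assumptions for
# illustration only; the sketch relies on the module-level imports.
def _example_monotonic_distribution(mode="parallel"):
  batch_size, max_time = 2, 5
  # Unnormalized attention scores, e.g. the output of a score function such
  # as `_bahdanau_score` or `_luong_score`.
  score = random_ops.random_normal([batch_size, max_time])
  # "Choosing" probabilities: a hard 0/1 decision in 'hard' mode, a plain
  # sigmoid otherwise (pre-sigmoid noise is omitted here for clarity).
  if mode == "hard":
    p_choose_i = math_ops.cast(score > 0, score.dtype)
  else:
    p_choose_i = math_ops.sigmoid(score)
  # Monotonic attention starts from a dirac distribution on the first memory
  # entry, matching `initial_alignments` of the monotonic mechanisms above.
  previous_alignments = array_ops.one_hot(
      array_ops.zeros([batch_size], dtype=dtypes.int32), max_time,
      dtype=score.dtype)
  # Convert choosing probabilities into a [batch_size, max_time] alignment.
  return monotonic_attention(p_choose_i, previous_alignments, mode)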
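

# Illustrative usage sketch for `BahdanauMonotonicAttention` (not part of this
# module's public API): build the mechanism from encoder outputs and score a
# single decoder query against the memory. The helper name
# `_example_bahdanau_monotonic` and all shapes/values are assumptions for
# illustration only.
def _example_bahdanau_monotonic():
  batch_size, max_time, encoder_depth, num_units = 4, 7, 16, 8
  encoder_outputs = random_ops.random_normal(
      [batch_size, max_time, encoder_depth])
  memory_sequence_length = ops.convert_to_tensor(
      [7, 5, 6, 3], dtype=dtypes.int32)
  mechanism = BahdanauMonotonicAttention(
      num_units=num_units,
      memory=encoder_outputs,
      memory_sequence_length=memory_sequence_length,
      sigmoid_noise=1.0,     # pre-sigmoid noise, see _monotonic_probability_fn
      score_bias_init=-4.0,  # negative init is recommended for long memories
      mode="parallel")
  # The query may have any depth; `query_layer` projects it to `num_units`.
  query = random_ops.random_normal([batch_size, 12])
  # Monotonic mechanisms carry their previous alignments as state.
  state = mechanism.initial_alignments(batch_size, dtypes.float32)
  alignments, next_state = mechanism(query, state)
  return alignments, next_state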
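

# Illustrative usage sketch for the Keras-based `BahdanauMonotonicAttentionV2`
# (not part of this module's public API). Unlike the V1 mechanism, it is
# invoked with a single `[query, state]` list, which is the convention
# `_compute_attention` relies on. The helper name
# `_example_bahdanau_monotonic_v2` and all shapes are assumptions for
# illustration only.
def _example_bahdanau_monotonic_v2():
  batch_size, max_time, encoder_depth, units = 4, 7, 16, 8
  encoder_outputs = random_ops.random_normal(
      [batch_size, max_time, encoder_depth])
  mechanism = BahdanauMonotonicAttentionV2(
      units=units,
      memory=encoder_outputs,
      normalize=True,
      sigmoid_noise=1.0,
      score_bias_init=-1.0)
  query = random_ops.random_normal([batch_size, units])
  state = mechanism.initial_alignments(batch_size, dtypes.float32)
  # Keras-style call: query and previous alignments packed into one list.
  alignments, next_state = mechanism([query, state])
  return alignments, next_state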
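

# Illustrative usage sketch for `LuongMonotonicAttention` (not part of this
# module's public API). Because `query_layer` is None for Luong-style
# mechanisms, the decoder query must already have depth `num_units`, i.e. the
# depth the memory is projected to. The helper name `_example_luong_monotonic`
# and all shapes are assumptions for illustration only.
def _example_luong_monotonic():
  batch_size, max_time, encoder_depth, num_units = 4, 7, 16, 8
  encoder_outputs = random_ops.random_normal(
      [batch_size, max_time, encoder_depth])
  mechanism = LuongMonotonicAttention(
      num_units=num_units,
      memory=encoder_outputs,
      scale=True,
      score_bias_init=-2.0)
  query = random_ops.random_normal([batch_size, num_units])
  state = mechanism.initial_alignments(batch_size, dtypes.float32)
  alignments, next_state = mechanism(query, state)
  return alignments, next_state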
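

# Illustrative sketch of the masking helpers above (not part of this module's
# public API): `_prepare_memory` zeroes out memory rows past each entry's
# sequence length, and `_maybe_mask_score` pushes the corresponding scores to
# the mask value so they receive essentially zero probability. The helper
# name `_example_masking` and all shapes are assumptions for illustration
# only.
def _example_masking():
  batch_size, max_time, depth = 2, 5, 3
  memory = random_ops.random_normal([batch_size, max_time, depth])
  memory_sequence_length = ops.convert_to_tensor([5, 2], dtype=dtypes.int32)
  # Positions at or beyond the sequence length are multiplied by a zero mask.
  masked_memory = _prepare_memory(
      memory, memory_sequence_length=memory_sequence_length,
      check_inner_dims_defined=True)
  score = random_ops.random_normal([batch_size, max_time])
  # Masked positions get -inf so that e.g. a softmax assigns them ~0 weight.
  masked_score = _maybe_mask_score(
      score, memory_sequence_length=memory_sequence_length,
      score_mask_value=np.float32(-np.inf))
  return masked_memory, masked_score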
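

# Illustrative sketch for `hardmax` (not part of this module's public API):
# it returns batched one-hot vectors with the 1 at the argmax of the logits.
# The helper name `_example_hardmax` is an assumption for illustration only.
def _example_hardmax():
  logits = ops.convert_to_tensor(
      [[0.1, 2.0, -1.0],
       [3.0, 0.0, 0.5]], dtype=dtypes.float32)
  # Evaluates to [[0., 1., 0.], [1., 0., 0.]].
  return hardmax(logits)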
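

# Illustrative sketch of the context-vector computation in
# `_compute_attention` above (not part of this module's public API): the
# alignments form a [batch_size, memory_time] distribution, and the context
# is its inner product with the memory values along the time axis. The helper
# name `_example_context_vector` is an assumption for illustration only.
def _example_context_vector(alignments, values):
  # alignments: [batch_size, memory_time]
  # values:     [batch_size, memory_time, memory_size]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  # Batched matmul over memory_time gives [batch_size, 1, memory_size].
  context = math_ops.matmul(expanded_alignments, values)
  # Squeeze out the singleton time dimension: [batch_size, memory_size].
  return array_ops.squeeze(context, [1])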
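

# Illustrative end-to-end sketch for `AttentionWrapper` without beam search
# (not part of this module's public API): wrap an LSTM cell with a monotonic
# Luong mechanism, build the zero state, and run a single decode step. The
# helper name `_example_attention_wrapper_step` and all shapes are
# assumptions for illustration only.
def _example_attention_wrapper_step():
  batch_size, max_time, encoder_depth = 4, 7, 16
  num_units, input_depth, attention_layer_size = 8, 6, 10
  encoder_outputs = random_ops.random_normal(
      [batch_size, max_time, encoder_depth])
  mechanism = LuongMonotonicAttention(
      num_units=num_units, memory=encoder_outputs)
  cell = rnn_cell_impl.LSTMCell(num_units)
  attention_cell = AttentionWrapper(
      cell,
      mechanism,
      attention_layer_size=attention_layer_size,
      alignment_history=True,
      output_attention=True)
  state = attention_cell.zero_state(batch_size, dtypes.float32)
  inputs = random_ops.random_normal([batch_size, input_depth])
  # One step: with output_attention=True the output is the attention vector;
  # next_state is an AttentionWrapperState carrying the cell state,
  # alignments, attention state and (here) the alignment-history TensorArray.
  output, next_state = attention_cell(inputs, state)
  return output, next_state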