1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A powerful dynamic attention wrapper object."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import math
24
25import numpy as np
26
27from tensorflow.contrib.framework.python.framework import tensor_util
28from tensorflow.python.eager import context
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.keras import initializers
33from tensorflow.python.keras import layers
34from tensorflow.python.keras.engine import base_layer_utils
35from tensorflow.python.layers import base as layers_base
36from tensorflow.python.layers import core as layers_core
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import check_ops
39from tensorflow.python.ops import clip_ops
40from tensorflow.python.ops import functional_ops
41from tensorflow.python.ops import init_ops
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import nn_ops
44from tensorflow.python.ops import random_ops
45from tensorflow.python.ops import rnn_cell_impl
46from tensorflow.python.ops import tensor_array_ops
47from tensorflow.python.ops import variable_scope
48from tensorflow.python.util import nest
49
50
51__all__ = [
52    "AttentionMechanism",
53    "AttentionWrapper",
54    "AttentionWrapperState",
55    "LuongAttention",
56    "BahdanauAttention",
57    "hardmax",
58    "safe_cumprod",
59    "monotonic_attention",
60    "BahdanauMonotonicAttention",
61    "LuongMonotonicAttention",
62]
63
64
65_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access
66
67
68class AttentionMechanism(object):
69
70  @property
71  def alignments_size(self):
72    raise NotImplementedError
73
74  @property
75  def state_size(self):
76    raise NotImplementedError
77
78
79class _BaseAttentionMechanism(AttentionMechanism):
80  """A base AttentionMechanism class providing common functionality.
81
82  Common functionality includes:
83    1. Storing the query and memory layers.
84    2. Preprocessing and storing the memory.
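
  Subclasses typically compute alignments along these lines (a sketch;
  `score_fn` stands in for the subclass-specific scoring function, and the
  query is first passed through `query_layer` when one is present):

    keys = memory_layer(values)        # processed memory, built here
    score = score_fn(query, keys)      # e.g. Luong or Bahdanau scoring
    alignments = probability_fn(score, previous_state)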
85  """
86
87  def __init__(self,
88               query_layer,
89               memory,
90               probability_fn,
91               memory_sequence_length=None,
92               memory_layer=None,
93               check_inner_dims_defined=True,
94               score_mask_value=None,
95               name=None):
96    """Construct base AttentionMechanism class.
97
98    Args:
99      query_layer: Callable.  Instance of `tf.layers.Layer`.  The layer's depth
100        must match the depth of `memory_layer`.  If `query_layer` is not
101        provided, the shape of `query` must match that of `memory_layer`.
102      memory: The memory to query; usually the output of an RNN encoder.  This
103        tensor should be shaped `[batch_size, max_time, ...]`.
104      probability_fn: A `callable`.  Converts the score and previous alignments
105        to probabilities. Its signature should be:
106        `probabilities = probability_fn(score, state)`.
107      memory_sequence_length (optional): Sequence lengths for the batch entries
108        in memory.  If provided, the memory tensor rows are masked with zeros
109        for values past the respective sequence lengths.
110      memory_layer: Instance of `tf.layers.Layer` (may be None).  The layer's
111        depth must match the depth of `query_layer`.
112        If `memory_layer` is not provided, the shape of `memory` must match
113        that of `query_layer`.
114      check_inner_dims_defined: Python boolean.  If `True`, the `memory`
115        argument's shape is checked to ensure all but the two outermost
116        dimensions are fully defined.
117      score_mask_value: (optional): The mask value for score before passing into
118        `probability_fn`. The default is -inf. Only used if
119        `memory_sequence_length` is not None.
120      name: Name to use when creating ops.
121    """
122    if (query_layer is not None
123        and not isinstance(query_layer, layers_base.Layer)):
124      raise TypeError(
125          "query_layer is not a Layer: %s" % type(query_layer).__name__)
126    if (memory_layer is not None
127        and not isinstance(memory_layer, layers_base.Layer)):
128      raise TypeError(
129          "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
130    self._query_layer = query_layer
131    self._memory_layer = memory_layer
132    self.dtype = memory_layer.dtype
133    if not callable(probability_fn):
134      raise TypeError("probability_fn must be callable, saw type: %s" %
135                      type(probability_fn).__name__)
136    if score_mask_value is None:
137      score_mask_value = dtypes.as_dtype(
138          self._memory_layer.dtype).as_numpy_dtype(-np.inf)
139    self._probability_fn = lambda score, prev: (  # pylint:disable=g-long-lambda
140        probability_fn(
141            _maybe_mask_score(score,
142                              memory_sequence_length=memory_sequence_length,
143                              score_mask_value=score_mask_value),
144            prev))
145    with ops.name_scope(
146        name, "BaseAttentionMechanismInit", nest.flatten(memory)):
147      self._values = _prepare_memory(
148          memory, memory_sequence_length=memory_sequence_length,
149          check_inner_dims_defined=check_inner_dims_defined)
150      self._keys = (
151          self.memory_layer(self._values) if self.memory_layer  # pylint: disable=not-callable
152          else self._values)
153      self._batch_size = (
154          tensor_shape.dimension_value(self._keys.shape[0]) or
155          array_ops.shape(self._keys)[0])
156      self._alignments_size = (tensor_shape.dimension_value(self._keys.shape[1])
157                               or array_ops.shape(self._keys)[1])
158
159  @property
160  def memory_layer(self):
161    return self._memory_layer
162
163  @property
164  def query_layer(self):
165    return self._query_layer
166
167  @property
168  def values(self):
169    return self._values
170
171  @property
172  def keys(self):
173    return self._keys
174
175  @property
176  def batch_size(self):
177    return self._batch_size
178
179  @property
180  def alignments_size(self):
181    return self._alignments_size
182
183  @property
184  def state_size(self):
185    return self._alignments_size
186
187  def initial_alignments(self, batch_size, dtype):
188    """Creates the initial alignment values for the `AttentionWrapper` class.
189
190    This is important for AttentionMechanisms that use the previous alignment
191    to calculate the alignment at the next time step (e.g. monotonic attention).
192
193    The default behavior is to return a tensor of all zeros.
194
195    Args:
196      batch_size: `int32` scalar, the batch_size.
197      dtype: The `dtype`.
198
199    Returns:
200      A `dtype` tensor shaped `[batch_size, alignments_size]`
201      (`alignments_size` is the values' `max_time`).
202    """
203    max_time = self._alignments_size
204    return _zero_state_tensors(max_time, batch_size, dtype)
205
206  def initial_state(self, batch_size, dtype):
207    """Creates the initial state values for the `AttentionWrapper` class.
208
209    This is important for AttentionMechanisms that use the previous alignment
210    to calculate the alignment at the next time step (e.g. monotonic attention).
211
212    The default behavior is to return the same output as initial_alignments.
213
214    Args:
215      batch_size: `int32` scalar, the batch_size.
216      dtype: The `dtype`.
217
218    Returns:
219      A structure of all-zero tensors with shapes as described by `state_size`.
220    """
221    return self.initial_alignments(batch_size, dtype)
222
223
224class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer):
225  """A base AttentionMechanism class providing common functionality.
226
227  Common functionality includes:
228    1. Storing the query and memory layers.
229    2. Preprocessing and storing the memory.
230
  Note that this layer takes memory as its init parameter, which is an
  anti-pattern for the Keras API; we have to keep memory as an init parameter
  for performance and dependency reasons. Under the hood, during `__init__()`,
  it will invoke `base_layer.__call__(memory, setup_memory=True)`. This lets
  Keras keep track of the memory tensor as an input of this layer. Once
  `__init__()` is done, the user can query the attention via
  `score = att_obj([query, state])` and use it as a normal Keras layer.

  Special attention is needed when using this class as the base layer for a
  new attention mechanism:
    1. build() may be invoked more than once, so make sure weights are not
       duplicated.
    2. Layer.get_weights() might return a different set of weights if the
       instance has a `query_layer`; the `query_layer` weights are not
       initialized until the memory is configured.

  Also note that this layer does not work with a Keras model compiled with
  `model.compile(run_eagerly=True)`, because this layer is stateful. Support
  for that will be added in a future version.
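
  A rough usage sketch with a concrete subclass (tensor names are
  illustrative):

    attention = LuongAttentionV2(units=128, memory=encoder_outputs,
                                 memory_sequence_length=source_lengths)
    # Memory is set up during __init__(); afterwards, query it like a normal
    # Keras layer:
    alignments, next_state = attention([query, previous_alignments])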
250  """
251
252  def __init__(self,
253               memory,
254               probability_fn,
255               query_layer=None,
256               memory_layer=None,
257               memory_sequence_length=None,
258               **kwargs):
259    """Construct base AttentionMechanism class.
260
261    Args:
262      memory: The memory to query; usually the output of an RNN encoder.  This
263        tensor should be shaped `[batch_size, max_time, ...]`.
264      probability_fn: A `callable`. Converts the score and previous alignments
265        to probabilities. Its signature should be:
266        `probabilities = probability_fn(score, state)`.
267      query_layer:  (optional): Instance of `tf.keras.Layer`.  The layer's depth
268        must match the depth of `memory_layer`.  If `query_layer` is not
269        provided, the shape of `query` must match that of `memory_layer`.
270      memory_layer: (optional): Instance of `tf.keras.Layer`. The layer's
271        depth must match the depth of `query_layer`.
272        If `memory_layer` is not provided, the shape of `memory` must match
273        that of `query_layer`.
274      memory_sequence_length (optional): Sequence lengths for the batch entries
275        in memory. If provided, the memory tensor rows are masked with zeros
276        for values past the respective sequence lengths.
277      **kwargs: Dictionary that contains other common arguments for layer
278        creation.
279    """
280    if (query_layer is not None
281        and not isinstance(query_layer, layers.Layer)):
282      raise TypeError(
283          "query_layer is not a Layer: %s" % type(query_layer).__name__)
284    if (memory_layer is not None
285        and not isinstance(memory_layer, layers.Layer)):
286      raise TypeError(
287          "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
288    self.query_layer = query_layer
289    self.memory_layer = memory_layer
290    if self.memory_layer is not None and "dtype" not in kwargs:
291      kwargs["dtype"] = self.memory_layer.dtype
292    super(_BaseAttentionMechanismV2, self).__init__(**kwargs)
293    if not callable(probability_fn):
294      raise TypeError("probability_fn must be callable, saw type: %s" %
295                      type(probability_fn).__name__)
296    self.probability_fn = probability_fn
297
298    self.keys = None
299    self.values = None
300    self.batch_size = None
301    self._memory_initialized = False
302    self._check_inner_dims_defined = True
303    self.supports_masking = True
304    self.score_mask_value = dtypes.as_dtype(self.dtype).as_numpy_dtype(-np.inf)
305
306    if memory is not None:
      # Set up the memory via self.__call__() with memory and
      # memory_sequence_length. This makes the attention follow the Keras
      # convention of taking all tensor inputs via __call__().
310      if memory_sequence_length is None:
311        inputs = memory
312      else:
313        inputs = [memory, memory_sequence_length]
314
315      self.values = super(_BaseAttentionMechanismV2, self).__call__(
316          inputs, setup_memory=True)
317
318  def build(self, input_shape):
319    if not self._memory_initialized:
320      # This is for setting up the memory, which contains memory and optional
321      # memory_sequence_length. Build the memory_layer with memory shape.
322      if self.memory_layer is not None and not self.memory_layer.built:
323        if isinstance(input_shape, list):
324          self.memory_layer.build(input_shape[0])
325        else:
326          self.memory_layer.build(input_shape)
327    else:
328      # The input_shape should be query.shape and state.shape. Use the query
329      # to init the query layer.
330      if self.query_layer is not None and not self.query_layer.built:
331        self.query_layer.build(input_shape[0])
332
333  def __call__(self, inputs, **kwargs):
334    """Preprocess the inputs before calling `base_layer.__call__()`.
335
    Note that there are two situations here: one for setting up the memory, and
    one for handling the actual query and state.
    1. When the memory has not been configured, we just pass all the parameters
    to base_layer.__call__(), which then invokes self.call() with the proper
    inputs and allows this class to set up the memory.
    2. When the memory has already been set up, the input should contain the
    query and state, and optionally the processed memory. If the processed
    memory is not included in the input, we have to append it to the inputs and
    give it to base_layer.__call__(). The processed memory is the output of the
    first invocation of self.__call__(). If we didn't add it here, then from
    the Keras perspective the graph would be disconnected, since the output of
    the previous call would never be used.

    Args:
      inputs: the input tensors.
      **kwargs: dict, other keyword arguments for `__call__()`.
352    """
353    if self._memory_initialized:
354      if len(inputs) not in (2, 3):
355        raise ValueError("Expect the inputs to have 2 or 3 tensors, got %d" %
356                         len(inputs))
357      if len(inputs) == 2:
358        # We append the calculated memory here so that the graph will be
359        # connected.
360        inputs.append(self.values)
361    return super(_BaseAttentionMechanismV2, self).__call__(inputs, **kwargs)
362
363  def call(self, inputs, mask=None, setup_memory=False, **kwargs):
364    """Setup the memory or query the attention.
365
366    There are two case here, one for setup memory, and the second is query the
367    attention score. `setup_memory` is the flag to indicate which mode it is.
368    The input list will be treated differently based on that flag.
369
370    Args:
371      inputs: a list of tensor that could either be `query` and `state`, or
372        `memory` and `memory_sequence_length`.
373        `query` is the tensor of dtype matching `memory` and shape
374        `[batch_size, query_depth]`.
375        `state` is the tensor of dtype matching `memory` and shape
376        `[batch_size, alignments_size]`. (`alignments_size` is memory's
377        `max_time`).
378        `memory` is the memory to query; usually the output of an RNN encoder.
379        The tensor should be shaped `[batch_size, max_time, ...]`.
        `memory_sequence_length` (optional) is the sequence lengths for the
        batch entries in memory. If provided, the memory tensor rows are masked
        with zeros for values past the respective sequence lengths.
      mask: optional bool tensor with shape `[batch, max_time]` acting as the
        memory mask. If it is not None, memory positions where the mask is
        False are filtered out during the calculation.
386      setup_memory: boolean, whether the input is for setting up memory, or
387        query attention.
388      **kwargs: Dict, other keyword arguments for the call method.
389    Returns:
390      Either processed memory or attention score, based on `setup_memory`.
391    """
392    if setup_memory:
393      if isinstance(inputs, list):
394        if len(inputs) not in (1, 2):
395          raise ValueError("Expect inputs to have 1 or 2 tensors, got %d" %
396                           len(inputs))
397        memory = inputs[0]
398        memory_sequence_length = inputs[1] if len(inputs) == 2 else None
399        memory_mask = mask
400      else:
401        memory, memory_sequence_length = inputs, None
402        memory_mask = mask
403      self._setup_memory(memory, memory_sequence_length, memory_mask)
      # We force self.built to False here since only the memory has been
      # initialized; the real query/state has not gone through call() yet.
      # The layer should be built and called again.
407      self.built = False
408      # Return the processed memory in order to create the Keras connectivity
409      # data for it.
410      return self.values
411    else:
412      if not self._memory_initialized:
413        raise ValueError("Cannot query the attention before the setup of "
414                         "memory")
415      if len(inputs) not in (2, 3):
416        raise ValueError("Expect the inputs to have query, state, and optional "
417                         "processed memory, got %d items" % len(inputs))
418      # Ignore the rest of the inputs and only care about the query and state
419      query, state = inputs[0], inputs[1]
420      return self._calculate_attention(query, state)
421
422  def _setup_memory(self, memory, memory_sequence_length=None,
423                    memory_mask=None):
424    """Pre-process the memory before actually query the memory.
425
426    This should only be called once at the first invocation of call().
427
428    Args:
429      memory: The memory to query; usually the output of an RNN encoder. This
430        tensor should be shaped `[batch_size, max_time, ...]`.
431      memory_sequence_length (optional): Sequence lengths for the batch entries
432        in memory. If provided, the memory tensor rows are masked with zeros for
433        values past the respective sequence lengths.
434      memory_mask: (Optional) The boolean tensor with shape `[batch_size,
435        max_time]`. For any value equal to False, the corresponding value in
436        memory should be ignored.
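
    Note that only one of `memory_sequence_length` and `memory_mask` may be
    provided; a boolean mask equivalent to a sequence-length argument can be
    built as, for example (a sketch):

      memory_mask = tf.sequence_mask(memory_sequence_length, maxlen=max_time)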
437    """
438    if self._memory_initialized:
439      raise ValueError("The memory for the attention has already been setup.")
440    if memory_sequence_length is not None and memory_mask is not None:
441      raise ValueError("memory_sequence_length and memory_mask cannot be "
442                       "used at same time for attention.")
443    with ops.name_scope(
444        self.name, "BaseAttentionMechanismInit", nest.flatten(memory)):
445      self.values = _prepare_memory(
446          memory,
447          memory_sequence_length=memory_sequence_length,
448          memory_mask=memory_mask,
449          check_inner_dims_defined=self._check_inner_dims_defined)
      # Mark the value as checked since the memory and memory mask might not be
      # passed from __call__(), in which case they do not have proper Keras
      # metadata.
      # TODO(omalleyt): Remove this hack once the mask has proper Keras
      # history.
454      base_layer_utils.mark_checked(self.values)
455      if self.memory_layer is not None:
456        self.keys = self.memory_layer(self.values)
457      else:
458        self.keys = self.values
459      self.batch_size = (
460          tensor_shape.dimension_value(self.keys.shape[0]) or
461          array_ops.shape(self.keys)[0])
462      self._alignments_size = (tensor_shape.dimension_value(self.keys.shape[1])
463                               or array_ops.shape(self.keys)[1])
464      if memory_mask is not None:
465        unwrapped_probability_fn = self.probability_fn
466        def _mask_probability_fn(score, prev):
467          return unwrapped_probability_fn(
468              _maybe_mask_score(
469                  score,
470                  memory_mask=memory_mask,
471                  memory_sequence_length=memory_sequence_length,
472                  score_mask_value=self.score_mask_value), prev)
473        self.probability_fn = _mask_probability_fn
474    self._memory_initialized = True
475
476  def _calculate_attention(self, query, state):
477    raise NotImplementedError(
478        "_calculate_attention need to be implemented by subclasses.")
479
480  def compute_mask(self, inputs, mask=None):
    # The real input of the attention is the query and state; the memory layer
    # mask shouldn't be passed down. Return None for all output masks here.
483    return None, None
484
485  def get_config(self):
486    config = {}
    # Since the probability_fn is likely to be a wrapped function, the child
    # class should preserve the original function and how it is wrapped.
489
490    if self.query_layer is not None:
491      config["query_layer"] = {
492          "class_name": self.query_layer.__class__.__name__,
493          "config": self.query_layer.get_config(),
494      }
495    if self.memory_layer is not None:
496      config["memory_layer"] = {
497          "class_name": self.memory_layer.__class__.__name__,
498          "config": self.memory_layer.get_config(),
499      }
    # memory is a required init parameter and it's a tensor. It cannot be
    # serialized to config, so we put a placeholder for it.
502    config["memory"] = None
503    base_config = super(_BaseAttentionMechanismV2, self).get_config()
504    return dict(list(base_config.items()) + list(config.items()))
505
506  def _process_probability_fn(self, func_name):
507    """Helper method to retrieve the probably function by string input."""
508    valid_probability_fns = {
509        "softmax": nn_ops.softmax,
510        "hardmax": hardmax,
511    }
512    if func_name not in valid_probability_fns.keys():
513      raise ValueError("Invalid probability function: %s, options are %s" %
514                       (func_name, valid_probability_fns.keys()))
515    return valid_probability_fns[func_name]
516
517  @classmethod
518  def deserialize_inner_layer_from_config(cls, config, custom_objects):
519    """Helper method that reconstruct the query and memory from the config.
520
521    In the get_config() method, the query and memory layer configs are
522    serialized into dict for persistence, this method perform the reverse action
523    to reconstruct the layer from the config.
524
525    Args:
526      config: dict, the configs that will be used to reconstruct the object.
527      custom_objects: dict mapping class names (or function names) of custom
528        (non-Keras) objects to class/functions.
529    Returns:
      config: dict, the config with the layer instances created, ready to be
        used as init parameters.
532    """
533    # Reconstruct the query and memory layer for parent class.
534    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
535    # Instead of updating the input, create a copy and use that.
536    config = config.copy()
537    query_layer_config = config.pop("query_layer", None)
538    if query_layer_config:
539      query_layer = deserialize_layer(query_layer_config,
540                                      custom_objects=custom_objects)
541      config["query_layer"] = query_layer
542    memory_layer_config = config.pop("memory_layer", None)
543    if memory_layer_config:
544      memory_layer = deserialize_layer(memory_layer_config,
545                                       custom_objects=custom_objects)
546      config["memory_layer"] = memory_layer
547    return config
548
549  @property
550  def alignments_size(self):
551    return self._alignments_size
552
553  @property
554  def state_size(self):
555    return self._alignments_size
556
557  def initial_alignments(self, batch_size, dtype):
558    """Creates the initial alignment values for the `AttentionWrapper` class.
559
560    This is important for AttentionMechanisms that use the previous alignment
561    to calculate the alignment at the next time step (e.g. monotonic attention).
562
563    The default behavior is to return a tensor of all zeros.
564
565    Args:
566      batch_size: `int32` scalar, the batch_size.
567      dtype: The `dtype`.
568
569    Returns:
570      A `dtype` tensor shaped `[batch_size, alignments_size]`
571      (`alignments_size` is the values' `max_time`).
572    """
573    max_time = self._alignments_size
574    return _zero_state_tensors(max_time, batch_size, dtype)
575
576  def initial_state(self, batch_size, dtype):
577    """Creates the initial state values for the `AttentionWrapper` class.
578
579    This is important for AttentionMechanisms that use the previous alignment
580    to calculate the alignment at the next time step (e.g. monotonic attention).
581
582    The default behavior is to return the same output as initial_alignments.
583
584    Args:
585      batch_size: `int32` scalar, the batch_size.
586      dtype: The `dtype`.
587
588    Returns:
589      A structure of all-zero tensors with shapes as described by `state_size`.
590    """
591    return self.initial_alignments(batch_size, dtype)
592
593
594def _luong_score(query, keys, scale):
595  """Implements Luong-style (multiplicative) scoring function.
596
597  This attention has two forms.  The first is standard Luong attention,
598  as described in:
599
600  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
601  "Effective Approaches to Attention-based Neural Machine Translation."
602  EMNLP 2015.  https://arxiv.org/abs/1508.04025
603
604  The second is the scaled form inspired partly by the normalized form of
605  Bahdanau attention.
606
607  To enable the second form, call this function with `scale=True`.
608
609  Args:
610    query: Tensor, shape `[batch_size, num_units]` to compare to keys.
611    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    scale: optional scalar tensor used to scale the attention score; may be
      None.
613
614  Returns:
615    A `[batch_size, max_time]` tensor of unnormalized score values.
616
617  Raises:
618    ValueError: If `key` and `query` depths do not match.
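
  In index notation, the batched matmul below computes (roughly):

    score[b, t] = scale * sum_d query[b, d] * keys[b, t, d]

  where `scale` is taken to be 1 when not provided.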
619  """
620  depth = query.get_shape()[-1]
621  key_units = keys.get_shape()[-1]
622  if depth != key_units:
623    raise ValueError(
624        "Incompatible or unknown inner dimensions between query and keys.  "
625        "Query (%s) has units: %s.  Keys (%s) have units: %s.  "
626        "Perhaps you need to set num_units to the keys' dimension (%s)?"
627        % (query, depth, keys, key_units, key_units))
628
629  # Reshape from [batch_size, depth] to [batch_size, 1, depth]
630  # for matmul.
631  query = array_ops.expand_dims(query, 1)
632
633  # Inner product along the query units dimension.
634  # matmul shapes: query is [batch_size, 1, depth] and
635  #                keys is [batch_size, max_time, depth].
636  # the inner product is asked to **transpose keys' inner shape** to get a
637  # batched matmul on:
638  #   [batch_size, 1, depth] . [batch_size, depth, max_time]
639  # resulting in an output shape of:
640  #   [batch_size, 1, max_time].
641  # we then squeeze out the center singleton dimension.
642  score = math_ops.matmul(query, keys, transpose_b=True)
643  score = array_ops.squeeze(score, [1])
644
645  if scale is not None:
646    score = scale * score
647  return score
648
649
650class LuongAttention(_BaseAttentionMechanism):
651  """Implements Luong-style (multiplicative) attention scoring.
652
653  This attention has two forms.  The first is standard Luong attention,
654  as described in:
655
656  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
657  [Effective Approaches to Attention-based Neural Machine Translation.
658  EMNLP 2015.](https://arxiv.org/abs/1508.04025)
659
660  The second is the scaled form inspired partly by the normalized form of
661  Bahdanau attention.
662
663  To enable the second form, construct the object with parameter
664  `scale=True`.
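
  A minimal usage sketch (tensor and variable names are illustrative):

    attention_mechanism = LuongAttention(
        num_units=128, memory=encoder_outputs,
        memory_sequence_length=source_lengths)
    attn_cell = AttentionWrapper(
        tf.nn.rnn_cell.LSTMCell(128), attention_mechanism,
        attention_layer_size=64)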
665  """
666
667  def __init__(self,
668               num_units,
669               memory,
670               memory_sequence_length=None,
671               scale=False,
672               probability_fn=None,
673               score_mask_value=None,
674               dtype=None,
675               name="LuongAttention"):
676    """Construct the AttentionMechanism mechanism.
677
678    Args:
679      num_units: The depth of the attention mechanism.
680      memory: The memory to query; usually the output of an RNN encoder.  This
681        tensor should be shaped `[batch_size, max_time, ...]`.
682      memory_sequence_length: (optional) Sequence lengths for the batch entries
683        in memory.  If provided, the memory tensor rows are masked with zeros
684        for values past the respective sequence lengths.
685      scale: Python boolean.  Whether to scale the energy term.
686      probability_fn: (optional) A `callable`.  Converts the score to
687        probabilities.  The default is `tf.nn.softmax`. Other options include
688        `tf.contrib.seq2seq.hardmax` and `tf.contrib.sparsemax.sparsemax`.
689        Its signature should be: `probabilities = probability_fn(score)`.
690      score_mask_value: (optional) The mask value for score before passing into
691        `probability_fn`. The default is -inf. Only used if
692        `memory_sequence_length` is not None.
693      dtype: The data type for the memory layer of the attention mechanism.
694      name: Name to use when creating ops.
695    """
    # For LuongAttention, we only transform the memory layer; thus
    # num_units **must** match the expected query depth.
698    if probability_fn is None:
699      probability_fn = nn_ops.softmax
700    if dtype is None:
701      dtype = dtypes.float32
702    wrapped_probability_fn = lambda score, _: probability_fn(score)
703    super(LuongAttention, self).__init__(
704        query_layer=None,
705        memory_layer=layers_core.Dense(
706            num_units, name="memory_layer", use_bias=False, dtype=dtype),
707        memory=memory,
708        probability_fn=wrapped_probability_fn,
709        memory_sequence_length=memory_sequence_length,
710        score_mask_value=score_mask_value,
711        name=name)
712    self._num_units = num_units
713    self._scale = scale
714    self._name = name
715
716  def __call__(self, query, state):
717    """Score the query based on the keys and values.
718
719    Args:
720      query: Tensor of dtype matching `self.values` and shape
721        `[batch_size, query_depth]`.
722      state: Tensor of dtype matching `self.values` and shape
723        `[batch_size, alignments_size]`
724        (`alignments_size` is memory's `max_time`).
725
726    Returns:
727      alignments: Tensor of dtype matching `self.values` and shape
728        `[batch_size, alignments_size]` (`alignments_size` is memory's
729        `max_time`).
730    """
731    with variable_scope.variable_scope(None, "luong_attention", [query]):
732      attention_g = None
733      if self._scale:
734        attention_g = variable_scope.get_variable(
735            "attention_g", dtype=query.dtype,
736            initializer=init_ops.ones_initializer, shape=())
737      score = _luong_score(query, self._keys, attention_g)
738    alignments = self._probability_fn(score, state)
739    next_state = alignments
740    return alignments, next_state
741
742
743class LuongAttentionV2(_BaseAttentionMechanismV2):
744  """Implements Luong-style (multiplicative) attention scoring.
745
746  This attention has two forms.  The first is standard Luong attention,
747  as described in:
748
749  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
750  [Effective Approaches to Attention-based Neural Machine Translation.
751  EMNLP 2015.](https://arxiv.org/abs/1508.04025)
752
753  The second is the scaled form inspired partly by the normalized form of
754  Bahdanau attention.
755
756  To enable the second form, construct the object with parameter
757  `scale=True`.
758  """
759
760  def __init__(self,
761               units,
762               memory,
763               memory_sequence_length=None,
764               scale=False,
765               probability_fn="softmax",
766               dtype=None,
767               name="LuongAttention",
768               **kwargs):
769    """Construct the AttentionMechanism mechanism.
770
771    Args:
772      units: The depth of the attention mechanism.
773      memory: The memory to query; usually the output of an RNN encoder.  This
774        tensor should be shaped `[batch_size, max_time, ...]`.
775      memory_sequence_length: (optional): Sequence lengths for the batch entries
776        in memory.  If provided, the memory tensor rows are masked with zeros
777        for values past the respective sequence lengths.
778      scale: Python boolean. Whether to scale the energy term.
      probability_fn: (optional) string, the name of the function used to
        convert the attention score to probabilities. The default is `softmax`,
        which is `tf.nn.softmax`. The other option is `hardmax`, which is the
        `hardmax()` function within this module. Any other value will result in
        a validation error.
784      dtype: The data type for the memory layer of the attention mechanism.
785      name: Name to use when creating ops.
786      **kwargs: Dictionary that contains other common arguments for layer
787        creation.
788    """
    # For LuongAttention, we only transform the memory layer; thus
    # num_units **must** match the expected query depth.
791    self.probability_fn_name = probability_fn
792    probability_fn = self._process_probability_fn(self.probability_fn_name)
793    wrapped_probability_fn = lambda score, _: probability_fn(score)
794    if dtype is None:
795      dtype = dtypes.float32
796    memory_layer = kwargs.pop("memory_layer", None)
797    if not memory_layer:
798      memory_layer = layers.Dense(
799          units, name="memory_layer", use_bias=False, dtype=dtype)
800    self.units = units
801    self.scale = scale
802    self.scale_weight = None
803    super(LuongAttentionV2, self).__init__(
804        memory=memory,
805        memory_sequence_length=memory_sequence_length,
806        query_layer=None,
807        memory_layer=memory_layer,
808        probability_fn=wrapped_probability_fn,
809        name=name,
810        dtype=dtype,
811        **kwargs)
812
813  def build(self, input_shape):
814    super(LuongAttentionV2, self).build(input_shape)
815    if self.scale and self.scale_weight is None:
816      self.scale_weight = self.add_weight(
817          "attention_g", initializer=init_ops.ones_initializer, shape=())
818    self.built = True
819
820  def _calculate_attention(self, query, state):
821    """Score the query based on the keys and values.
822
823    Args:
824      query: Tensor of dtype matching `self.values` and shape
825        `[batch_size, query_depth]`.
826      state: Tensor of dtype matching `self.values` and shape
827        `[batch_size, alignments_size]`
828        (`alignments_size` is memory's `max_time`).
829
830    Returns:
831      alignments: Tensor of dtype matching `self.values` and shape
832        `[batch_size, alignments_size]` (`alignments_size` is memory's
833        `max_time`).
834      next_state: Same as the alignments.
835    """
836    score = _luong_score(query, self.keys, self.scale_weight)
837    alignments = self.probability_fn(score, state)
838    next_state = alignments
839    return alignments, next_state
840
841  def get_config(self):
842    config = {
843        "units": self.units,
844        "scale": self.scale,
845        "probability_fn": self.probability_fn_name,
846    }
847    base_config = super(LuongAttentionV2, self).get_config()
848    return dict(list(base_config.items()) + list(config.items()))
849
850  @classmethod
851  def from_config(cls, config, custom_objects=None):
852    config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config(
853        config, custom_objects=custom_objects)
854    return cls(**config)
855
856
857def _bahdanau_score(processed_query, keys, attention_v,
858                    attention_g=None, attention_b=None):
859  """Implements Bahdanau-style (additive) scoring function.
860
  This attention has two forms.  The first is Bahdanau attention,
862  as described in:
863
864  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
865  "Neural Machine Translation by Jointly Learning to Align and Translate."
866  ICLR 2015. https://arxiv.org/abs/1409.0473
867
868  The second is the normalized form.  This form is inspired by the
869  weight normalization article:
870
871  Tim Salimans, Diederik P. Kingma.
872  "Weight Normalization: A Simple Reparameterization to Accelerate
873   Training of Deep Neural Networks."
874  https://arxiv.org/abs/1602.07868
875
  To enable the second form, please pass in attention_g and attention_b.
877
878  Args:
879    processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
880    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
881    attention_v: Tensor, shape `[num_units]`.
882    attention_g: Optional scalar tensor for normalization.
883    attention_b: Optional tensor with shape `[num_units]` for normalization.
884
885  Returns:
886    A `[batch_size, max_time]` tensor of unnormalized score values.
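
  In index notation, the computation below is (roughly):

    score[b, t] = sum_d v[d] * tanh(keys[b, t, d] + processed_query[b, d])

  for the unnormalized form; the normalized form replaces v with
  g * v / ||v|| and adds the bias b inside the tanh, where v = attention_v,
  g = attention_g, and b = attention_b.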
887  """
888  # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
889  processed_query = array_ops.expand_dims(processed_query, 1)
890  if attention_g is not None and attention_b is not None:
891    normed_v = attention_g * attention_v * math_ops.rsqrt(
892        math_ops.reduce_sum(math_ops.square(attention_v)))
893    return math_ops.reduce_sum(
894        normed_v * math_ops.tanh(keys + processed_query + attention_b), [2])
895  else:
896    return math_ops.reduce_sum(
897        attention_v * math_ops.tanh(keys + processed_query), [2])
898
899
900class BahdanauAttention(_BaseAttentionMechanism):
901  """Implements Bahdanau-style (additive) attention.
902
903  This attention has two forms.  The first is Bahdanau attention,
904  as described in:
905
906  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
907  "Neural Machine Translation by Jointly Learning to Align and Translate."
908  ICLR 2015. https://arxiv.org/abs/1409.0473
909
910  The second is the normalized form.  This form is inspired by the
911  weight normalization article:
912
913  Tim Salimans, Diederik P. Kingma.
914  "Weight Normalization: A Simple Reparameterization to Accelerate
915   Training of Deep Neural Networks."
916  https://arxiv.org/abs/1602.07868
917
918  To enable the second form, construct the object with parameter
919  `normalize=True`.
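
  A minimal usage sketch of the normalized form (tensor names are
  illustrative):

    attention_mechanism = BahdanauAttention(
        num_units=128, memory=encoder_outputs,
        memory_sequence_length=source_lengths, normalize=True)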
920  """
921
922  def __init__(self,
923               num_units,
924               memory,
925               memory_sequence_length=None,
926               normalize=False,
927               probability_fn=None,
928               score_mask_value=None,
929               dtype=None,
930               name="BahdanauAttention"):
931    """Construct the Attention mechanism.
932
933    Args:
934      num_units: The depth of the query mechanism.
935      memory: The memory to query; usually the output of an RNN encoder.  This
936        tensor should be shaped `[batch_size, max_time, ...]`.
937      memory_sequence_length (optional): Sequence lengths for the batch entries
938        in memory.  If provided, the memory tensor rows are masked with zeros
939        for values past the respective sequence lengths.
940      normalize: Python boolean.  Whether to normalize the energy term.
941      probability_fn: (optional) A `callable`.  Converts the score to
942        probabilities.  The default is `tf.nn.softmax`. Other options include
943        `tf.contrib.seq2seq.hardmax` and `tf.contrib.sparsemax.sparsemax`.
944        Its signature should be: `probabilities = probability_fn(score)`.
945      score_mask_value: (optional): The mask value for score before passing into
946        `probability_fn`. The default is -inf. Only used if
947        `memory_sequence_length` is not None.
948      dtype: The data type for the query and memory layers of the attention
949        mechanism.
950      name: Name to use when creating ops.
951    """
952    if probability_fn is None:
953      probability_fn = nn_ops.softmax
954    if dtype is None:
955      dtype = dtypes.float32
956    wrapped_probability_fn = lambda score, _: probability_fn(score)
957    super(BahdanauAttention, self).__init__(
958        query_layer=layers_core.Dense(
959            num_units, name="query_layer", use_bias=False, dtype=dtype),
960        memory_layer=layers_core.Dense(
961            num_units, name="memory_layer", use_bias=False, dtype=dtype),
962        memory=memory,
963        probability_fn=wrapped_probability_fn,
964        memory_sequence_length=memory_sequence_length,
965        score_mask_value=score_mask_value,
966        name=name)
967    self._num_units = num_units
968    self._normalize = normalize
969    self._name = name
970
971  def __call__(self, query, state):
972    """Score the query based on the keys and values.
973
974    Args:
975      query: Tensor of dtype matching `self.values` and shape
976        `[batch_size, query_depth]`.
977      state: Tensor of dtype matching `self.values` and shape
978        `[batch_size, alignments_size]`
979        (`alignments_size` is memory's `max_time`).
980
981    Returns:
982      alignments: Tensor of dtype matching `self.values` and shape
983        `[batch_size, alignments_size]` (`alignments_size` is memory's
984        `max_time`).
985    """
986    with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
987      processed_query = self.query_layer(query) if self.query_layer else query
988      attention_v = variable_scope.get_variable(
989          "attention_v", [self._num_units], dtype=query.dtype)
990      if not self._normalize:
991        attention_g = None
992        attention_b = None
993      else:
994        attention_g = variable_scope.get_variable(
995            "attention_g", dtype=query.dtype,
996            initializer=init_ops.constant_initializer(
997                math.sqrt((1. / self._num_units))),
998            shape=())
999        attention_b = variable_scope.get_variable(
1000            "attention_b", [self._num_units], dtype=query.dtype,
1001            initializer=init_ops.zeros_initializer())
1002
1003      score = _bahdanau_score(processed_query, self._keys, attention_v,
1004                              attention_g=attention_g, attention_b=attention_b)
1005    alignments = self._probability_fn(score, state)
1006    next_state = alignments
1007    return alignments, next_state
1008
1009
1010class BahdanauAttentionV2(_BaseAttentionMechanismV2):
1011  """Implements Bahdanau-style (additive) attention.
1012
1013  This attention has two forms.  The first is Bahdanau attention,
1014  as described in:
1015
1016  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
1017  "Neural Machine Translation by Jointly Learning to Align and Translate."
1018  ICLR 2015. https://arxiv.org/abs/1409.0473
1019
1020  The second is the normalized form.  This form is inspired by the
1021  weight normalization article:
1022
1023  Tim Salimans, Diederik P. Kingma.
1024  "Weight Normalization: A Simple Reparameterization to Accelerate
1025   Training of Deep Neural Networks."
1026  https://arxiv.org/abs/1602.07868
1027
1028  To enable the second form, construct the object with parameter
1029  `normalize=True`.
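
  Being a Keras layer, the mechanism can be round-tripped through its config
  (a sketch; since `memory` is not serialized, the restored instance still
  needs its memory to be set up before it can be queried):

    config = attention_mechanism.get_config()
    restored = BahdanauAttentionV2.from_config(config)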
1030  """
1031
1032  def __init__(self,
1033               units,
1034               memory,
1035               memory_sequence_length=None,
1036               normalize=False,
1037               probability_fn="softmax",
1038               kernel_initializer="glorot_uniform",
1039               dtype=None,
1040               name="BahdanauAttention",
1041               **kwargs):
1042    """Construct the Attention mechanism.
1043
1044    Args:
1045      units: The depth of the query mechanism.
1046      memory: The memory to query; usually the output of an RNN encoder.  This
1047        tensor should be shaped `[batch_size, max_time, ...]`.
1048      memory_sequence_length: (optional): Sequence lengths for the batch entries
1049        in memory.  If provided, the memory tensor rows are masked with zeros
1050        for values past the respective sequence lengths.
1051      normalize: Python boolean.  Whether to normalize the energy term.
      probability_fn: (optional) string, the name of the function used to
        convert the attention score to probabilities. The default is `softmax`,
        which is `tf.nn.softmax`. The other option is `hardmax`, which is the
        `hardmax()` function within this module. Any other value will result in
        a validation error.
1057      kernel_initializer: (optional), the name of the initializer for the
1058        attention kernel.
1059      dtype: The data type for the query and memory layers of the attention
1060        mechanism.
1061      name: Name to use when creating ops.
1062      **kwargs: Dictionary that contains other common arguments for layer
1063        creation.
1064    """
1065    self.probability_fn_name = probability_fn
1066    probability_fn = self._process_probability_fn(self.probability_fn_name)
1067    wrapped_probability_fn = lambda score, _: probability_fn(score)
1068    if dtype is None:
1069      dtype = dtypes.float32
1070    query_layer = kwargs.pop("query_layer", None)
1071    if not query_layer:
1072      query_layer = layers.Dense(
1073          units, name="query_layer", use_bias=False, dtype=dtype)
1074    memory_layer = kwargs.pop("memory_layer", None)
1075    if not memory_layer:
1076      memory_layer = layers.Dense(
1077          units, name="memory_layer", use_bias=False, dtype=dtype)
1078    self.units = units
1079    self.normalize = normalize
1080    self.kernel_initializer = initializers.get(kernel_initializer)
1081    self.attention_v = None
1082    self.attention_g = None
1083    self.attention_b = None
1084    super(BahdanauAttentionV2, self).__init__(
1085        memory=memory,
1086        memory_sequence_length=memory_sequence_length,
1087        query_layer=query_layer,
1088        memory_layer=memory_layer,
1089        probability_fn=wrapped_probability_fn,
1090        name=name,
1091        dtype=dtype,
1092        **kwargs)
1093
1094  def build(self, input_shape):
1095    super(BahdanauAttentionV2, self).build(input_shape)
1096    if self.attention_v is None:
1097      self.attention_v = self.add_weight(
1098          "attention_v", [self.units],
1099          dtype=self.dtype,
1100          initializer=self.kernel_initializer)
1101    if self.normalize and self.attention_g is None and self.attention_b is None:
1102      self.attention_g = self.add_weight(
1103          "attention_g", initializer=init_ops.constant_initializer(
1104              math.sqrt((1. / self.units))), shape=())
1105      self.attention_b = self.add_weight(
1106          "attention_b", shape=[self.units],
1107          initializer=init_ops.zeros_initializer())
1108    self.built = True
1109
1110  def _calculate_attention(self, query, state):
1111    """Score the query based on the keys and values.
1112
1113    Args:
1114      query: Tensor of dtype matching `self.values` and shape
1115        `[batch_size, query_depth]`.
1116      state: Tensor of dtype matching `self.values` and shape
1117        `[batch_size, alignments_size]`
1118        (`alignments_size` is memory's `max_time`).
1119
1120    Returns:
1121      alignments: Tensor of dtype matching `self.values` and shape
1122        `[batch_size, alignments_size]` (`alignments_size` is memory's
1123        `max_time`).
1124      next_state: same as alignments.
1125    """
1126    processed_query = self.query_layer(query) if self.query_layer else query
1127    score = _bahdanau_score(processed_query, self.keys, self.attention_v,
1128                            attention_g=self.attention_g,
1129                            attention_b=self.attention_b)
1130    alignments = self.probability_fn(score, state)
1131    next_state = alignments
1132    return alignments, next_state
1133
1134  def get_config(self):
1135    config = {
1136        "units": self.units,
1137        "normalize": self.normalize,
1138        "probability_fn": self.probability_fn_name,
1139        "kernel_initializer": initializers.serialize(self.kernel_initializer)
1140    }
1141    base_config = super(BahdanauAttentionV2, self).get_config()
1142    return dict(list(base_config.items()) + list(config.items()))
1143
1144  @classmethod
1145  def from_config(cls, config, custom_objects=None):
1146    config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config(
1147        config, custom_objects=custom_objects)
1148    return cls(**config)
1149
1150
1151def safe_cumprod(x, *args, **kwargs):
1152  """Computes cumprod of x in logspace using cumsum to avoid underflow.
1153
1154  The cumprod function and its gradient can result in numerical instabilities
1155  when its argument has very small and/or zero values.  As long as the argument
1156  is all positive, we can instead compute the cumulative product as
1157  exp(cumsum(log(x))).  This function can be called identically to tf.cumprod.
1158
1159  Args:
1160    x: Tensor to take the cumulative product of.
1161    *args: Passed on to cumsum; these are identical to those in cumprod.
1162    **kwargs: Passed on to cumsum; these are identical to those in cumprod.
1163  Returns:
1164    Cumulative product of x.
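
  For example (illustrative values):

    x = tf.constant([0.5, 0.1, 0.2])
    safe_cumprod(x)  # ~= [0.5, 0.05, 0.01], matching tf.cumprod(x)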
1165  """
1166  with ops.name_scope(None, "SafeCumprod", [x]):
1167    x = ops.convert_to_tensor(x, name="x")
1168    tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
1169    return math_ops.exp(math_ops.cumsum(
1170        math_ops.log(clip_ops.clip_by_value(x, tiny, 1)), *args, **kwargs))
1171
1172
1173def monotonic_attention(p_choose_i, previous_attention, mode):
1174  """Compute monotonic attention distribution from choosing probabilities.
1175
1176  Monotonic attention implies that the input sequence is processed in an
1177  explicitly left-to-right manner when generating the output sequence.  In
1178  addition, once an input sequence element is attended to at a given output
1179  timestep, elements occurring before it cannot be attended to at subsequent
1180  output timesteps.  This function generates attention distributions according
1181  to these assumptions.  For more information, see `Online and Linear-Time
1182  Attention by Enforcing Monotonic Alignments`.
1183
1184  Args:
1185    p_choose_i: Probability of choosing input sequence/memory element i.  Should
1186      be of shape (batch_size, input_sequence_length), and should all be in the
1187      range [0, 1].
1188    previous_attention: The attention distribution from the previous output
1189      timestep.  Should be of shape (batch_size, input_sequence_length).  For
      the first output timestep, previous_attention[n] should be [1, 0, 0, ...,
1191      0] for all n in [0, ... batch_size - 1].
1192    mode: How to compute the attention distribution.  Must be one of
1193      'recursive', 'parallel', or 'hard'.
1194        * 'recursive' uses tf.scan to recursively compute the distribution.
1195          This is slowest but is exact, general, and does not suffer from
1196          numerical instabilities.
1197        * 'parallel' uses parallelized cumulative-sum and cumulative-product
1198          operations to compute a closed-form solution to the recurrence
1199          relation defining the attention distribution.  This makes it more
1200          efficient than 'recursive', but it requires numerical checks which
1201          make the distribution non-exact.  This can be a problem in particular
1202          when input_sequence_length is long and/or p_choose_i has entries very
1203          close to 0 or 1.
1204        * 'hard' requires that the probabilities in p_choose_i are all either 0
1205          or 1, and subsequently uses a more efficient and exact solution.
1206
1207  Returns:
1208    A tensor of shape (batch_size, input_sequence_length) representing the
1209    attention distributions for each sequence in the batch.
1210
1211  Raises:
1212    ValueError: mode is not one of 'recursive', 'parallel', 'hard'.
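
  A toy 'hard' mode example (batch_size=1, input_sequence_length=4; values
  are illustrative):

    p_choose_i         = tf.constant([[0., 0., 1., 1.]])
    previous_attention = tf.constant([[0., 1., 0., 0.]])
    monotonic_attention(p_choose_i, previous_attention, mode='hard')
    # => [[0., 0., 1., 0.]]  (attention advances from index 1 to index 2)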
1213  """
1214  # Force things to be tensors
1215  p_choose_i = ops.convert_to_tensor(p_choose_i, name="p_choose_i")
1216  previous_attention = ops.convert_to_tensor(
1217      previous_attention, name="previous_attention")
1218  if mode == "recursive":
1219    # Use .shape[0] when it's not None, or fall back on symbolic shape
1220    batch_size = tensor_shape.dimension_value(
1221        p_choose_i.shape[0]) or array_ops.shape(p_choose_i)[0]
1222    # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]]
1223    shifted_1mp_choose_i = array_ops.concat(
1224        [array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1)
1225    # Compute attention distribution recursively as
1226    # q[i] = (1 - p_choose_i[i - 1])*q[i - 1] + previous_attention[i]
1227    # attention[i] = p_choose_i[i]*q[i]
1228    attention = p_choose_i*array_ops.transpose(functional_ops.scan(
1229        # Need to use reshape to remind TF of the shape between loop iterations
1230        lambda x, yz: array_ops.reshape(yz[0]*x + yz[1], (batch_size,)),
1231        # Loop variables yz[0] and yz[1]
1232        [array_ops.transpose(shifted_1mp_choose_i),
1233         array_ops.transpose(previous_attention)],
1234        # Initial value of x is just zeros
1235        array_ops.zeros((batch_size,))))
1236  elif mode == "parallel":
1237    # safe_cumprod computes cumprod in logspace with numeric checks
1238    cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True)
1239    # Compute recurrence relation solution
1240    attention = p_choose_i*cumprod_1mp_choose_i*math_ops.cumsum(
1241        previous_attention /
1242        # Clip cumprod_1mp to avoid divide-by-zero
1243        clip_ops.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), axis=1)
1244  elif mode == "hard":
1245    # Remove any probabilities before the index chosen last time step
1246    p_choose_i *= math_ops.cumsum(previous_attention, axis=1)
1247    # Now, use exclusive cumprod to remove probabilities after the first
1248    # chosen index, like so:
1249    # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1]
1250    # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0]
1251    # Product of above: [0, 0, 0, 1, 0, 0, 0, 0]
1252    attention = p_choose_i*math_ops.cumprod(
1253        1 - p_choose_i, axis=1, exclusive=True)
1254  else:
1255    raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.")
1256  return attention
1257
1258
1259def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
1260                              seed=None):
1261  """Attention probability function for monotonic attention.
1262
1263  Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage
1264  the model to make discrete attention decisions, passes them through a sigmoid
1265  to obtain "choosing" probabilities, and then calls monotonic_attention to
1266  obtain the attention distribution.  For more information, see
1267
1268  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
1269  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
1270  ICML 2017.  https://arxiv.org/abs/1704.00784
1271
1272  Args:
1273    score: Unnormalized attention scores, shape `[batch_size, alignments_size]`
1274    previous_alignments: Previous attention distribution, shape
1275      `[batch_size, alignments_size]`
1276    sigmoid_noise: Standard deviation of pre-sigmoid noise.  Setting this larger
1277      than 0 will encourage the model to produce large attention scores,
1278      effectively making the choosing probabilities discrete and the resulting
1279      attention distribution one-hot.  It should be set to 0 at test-time, and
1280      when hard attention is not desired.
1281    mode: How to compute the attention distribution.  Must be one of
1282      'recursive', 'parallel', or 'hard'.  See the docstring for
1283      `tf.contrib.seq2seq.monotonic_attention` for more information.
1284    seed: (optional) Random seed for pre-sigmoid noise.
1285
1286  Returns:
1287    A `[batch_size, alignments_size]`-shape tensor corresponding to the
1288    resulting attention distribution.
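
  A sketch of how this function is typically bound before being handed to an
  attention mechanism (mirroring the monotonic attention classes below):

  ```python
  wrapped_probability_fn = functools.partial(
      _monotonic_probability_fn, sigmoid_noise=1.0, mode="parallel", seed=None)
  # wrapped_probability_fn(score, previous_alignments) then returns the
  # monotonic attention distribution for the given scores.
  ```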
1289  """
1290  # Optionally add pre-sigmoid noise to the scores
1291  if sigmoid_noise > 0:
1292    noise = random_ops.random_normal(array_ops.shape(score), dtype=score.dtype,
1293                                     seed=seed)
1294    score += sigmoid_noise*noise
1295  # Compute "choosing" probabilities from the attention scores
1296  if mode == "hard":
1297    # When mode is hard, use a hard sigmoid
1298    p_choose_i = math_ops.cast(score > 0, score.dtype)
1299  else:
1300    p_choose_i = math_ops.sigmoid(score)
1301  # Convert from choosing probabilities to attention distribution
1302  return monotonic_attention(p_choose_i, previous_alignments, mode)
1303
1304
1305class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism):
1306  """Base attention mechanism for monotonic attention.
1307
  Simply overrides the initial_alignments function to provide a Dirac
1309  distribution, which is needed in order for the monotonic attention
1310  distributions to have the correct behavior.
1311  """
1312
1313  def initial_alignments(self, batch_size, dtype):
1314    """Creates the initial alignment values for the monotonic attentions.
1315
    Initializes to Dirac distributions, i.e. [1, 0, 0, ...memory length..., 0]
1317    for all entries in the batch.
1318
1319    Args:
1320      batch_size: `int32` scalar, the batch_size.
1321      dtype: The `dtype`.
1322
1323    Returns:
1324      A `dtype` tensor shaped `[batch_size, alignments_size]`
1325      (`alignments_size` is the values' `max_time`).
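
    For example (illustrative sizes; `mechanism` stands for any monotonic
    attention mechanism whose memory has length 3):

    ```python
    mechanism.initial_alignments(batch_size=2, dtype=tf.float32)
    # => [[1., 0., 0.],
    #     [1., 0., 0.]]
    ```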
1326    """
1327    max_time = self._alignments_size
1328    return array_ops.one_hot(
1329        array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time,
1330        dtype=dtype)
1331
1332
1333class _BaseMonotonicAttentionMechanismV2(_BaseAttentionMechanismV2):
1334  """Base attention mechanism for monotonic attention.
1335
  Simply overrides the initial_alignments function to provide a Dirac
1337  distribution, which is needed in order for the monotonic attention
1338  distributions to have the correct behavior.
1339  """
1340
1341  def initial_alignments(self, batch_size, dtype):
1342    """Creates the initial alignment values for the monotonic attentions.
1343
    Initializes to Dirac distributions, i.e. [1, 0, 0, ...memory length..., 0]
1345    for all entries in the batch.
1346
1347    Args:
1348      batch_size: `int32` scalar, the batch_size.
1349      dtype: The `dtype`.
1350
1351    Returns:
1352      A `dtype` tensor shaped `[batch_size, alignments_size]`
1353      (`alignments_size` is the values' `max_time`).
1354    """
1355    max_time = self._alignments_size
1356    return array_ops.one_hot(
1357        array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time,
1358        dtype=dtype)
1359
1360
1361class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
  """Monotonic attention mechanism with Bahdanau-style energy function.
1363
1364  This type of attention enforces a monotonic constraint on the attention
  distributions; that is, once the model attends to a given point in the memory
  it can't attend to any prior points at subsequent output timesteps.  It
1367  achieves this by using the _monotonic_probability_fn instead of softmax to
1368  construct its attention distributions.  Since the attention scores are passed
1369  through a sigmoid, a learnable scalar bias parameter is applied after the
1370  score function and before the sigmoid.  Otherwise, it is equivalent to
1371  BahdanauAttention.  This approach is proposed in
1372
1373  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
1374  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
1375  ICML 2017.  https://arxiv.org/abs/1704.00784
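
  Example (a minimal construction sketch; `encoder_outputs`, `source_lengths`
  and `cell` are assumed to be defined elsewhere):

  ```python
  attention_mechanism = tf.contrib.seq2seq.BahdanauMonotonicAttention(
      num_units=128,
      memory=encoder_outputs,
      memory_sequence_length=source_lengths,
      sigmoid_noise=1.0,  # > 0 during training; use 0 at test time
      mode="parallel")
  attention_cell = tf.contrib.seq2seq.AttentionWrapper(
      cell, attention_mechanism)
  ```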
1376  """
1377
1378  def __init__(self,
1379               num_units,
1380               memory,
1381               memory_sequence_length=None,
1382               normalize=False,
1383               score_mask_value=None,
1384               sigmoid_noise=0.,
1385               sigmoid_noise_seed=None,
1386               score_bias_init=0.,
1387               mode="parallel",
1388               dtype=None,
1389               name="BahdanauMonotonicAttention"):
1390    """Construct the Attention mechanism.
1391
1392    Args:
1393      num_units: The depth of the query mechanism.
1394      memory: The memory to query; usually the output of an RNN encoder.  This
1395        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
1397        in memory.  If provided, the memory tensor rows are masked with zeros
1398        for values past the respective sequence lengths.
1399      normalize: Python boolean.  Whether to normalize the energy term.
      score_mask_value: (optional) The mask value for score before passing into
1401        `probability_fn`. The default is -inf. Only used if
1402        `memory_sequence_length` is not None.
1403      sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
1404        for `_monotonic_probability_fn` for more information.
1405      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
1406      score_bias_init: Initial value for score bias scalar.  It's recommended to
1407        initialize this to a negative value when the length of the memory is
1408        large.
1409      mode: How to compute the attention distribution.  Must be one of
1410        'recursive', 'parallel', or 'hard'.  See the docstring for
1411        `tf.contrib.seq2seq.monotonic_attention` for more information.
1412      dtype: The data type for the query and memory layers of the attention
1413        mechanism.
1414      name: Name to use when creating ops.
1415    """
1416    # Set up the monotonic probability fn with supplied parameters
1417    if dtype is None:
1418      dtype = dtypes.float32
1419    wrapped_probability_fn = functools.partial(
1420        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
1421        seed=sigmoid_noise_seed)
1422    super(BahdanauMonotonicAttention, self).__init__(
1423        query_layer=layers_core.Dense(
1424            num_units, name="query_layer", use_bias=False, dtype=dtype),
1425        memory_layer=layers_core.Dense(
1426            num_units, name="memory_layer", use_bias=False, dtype=dtype),
1427        memory=memory,
1428        probability_fn=wrapped_probability_fn,
1429        memory_sequence_length=memory_sequence_length,
1430        score_mask_value=score_mask_value,
1431        name=name)
1432    self._num_units = num_units
1433    self._normalize = normalize
1434    self._name = name
1435    self._score_bias_init = score_bias_init
1436
1437  def __call__(self, query, state):
1438    """Score the query based on the keys and values.
1439
1440    Args:
1441      query: Tensor of dtype matching `self.values` and shape
1442        `[batch_size, query_depth]`.
1443      state: Tensor of dtype matching `self.values` and shape
1444        `[batch_size, alignments_size]`
1445        (`alignments_size` is memory's `max_time`).
1446
1447    Returns:
1448      alignments: Tensor of dtype matching `self.values` and shape
1449        `[batch_size, alignments_size]` (`alignments_size` is memory's
1450        `max_time`).
1451    """
1452    with variable_scope.variable_scope(
1453        None, "bahdanau_monotonic_attention", [query]):
1454      processed_query = self.query_layer(query) if self.query_layer else query
1455      attention_v = variable_scope.get_variable(
1456          "attention_v", [self._num_units], dtype=query.dtype)
1457      if not self._normalize:
1458        attention_g = None
1459        attention_b = None
1460      else:
1461        attention_g = variable_scope.get_variable(
1462            "attention_g", dtype=query.dtype,
1463            initializer=init_ops.constant_initializer(
1464                math.sqrt((1. / self._num_units))),
1465            shape=())
1466        attention_b = variable_scope.get_variable(
1467            "attention_b", [self._num_units], dtype=query.dtype,
1468            initializer=init_ops.zeros_initializer())
1469      score = _bahdanau_score(processed_query, self._keys, attention_v,
1470                              attention_g=attention_g, attention_b=attention_b)
1471      score_bias = variable_scope.get_variable(
1472          "attention_score_bias", dtype=processed_query.dtype,
1473          initializer=self._score_bias_init)
1474      score += score_bias
1475    alignments = self._probability_fn(score, state)
1476    next_state = alignments
1477    return alignments, next_state
1478
1479
1480class BahdanauMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2):
  """Monotonic attention mechanism with Bahdanau-style energy function.
1482
1483  This type of attention enforces a monotonic constraint on the attention
  distributions; that is, once the model attends to a given point in the memory
  it can't attend to any prior points at subsequent output timesteps.  It
1486  achieves this by using the _monotonic_probability_fn instead of softmax to
1487  construct its attention distributions.  Since the attention scores are passed
1488  through a sigmoid, a learnable scalar bias parameter is applied after the
1489  score function and before the sigmoid.  Otherwise, it is equivalent to
1490  BahdanauAttention.  This approach is proposed in
1491
1492  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
1493  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
1494  ICML 2017.  https://arxiv.org/abs/1704.00784
1495  """
1496
1497  def __init__(self,
1498               units,
1499               memory,
1500               memory_sequence_length=None,
1501               normalize=False,
1502               sigmoid_noise=0.,
1503               sigmoid_noise_seed=None,
1504               score_bias_init=0.,
1505               mode="parallel",
1506               kernel_initializer="glorot_uniform",
1507               dtype=None,
1508               name="BahdanauMonotonicAttention",
1509               **kwargs):
1510    """Construct the Attention mechanism.
1511
1512    Args:
1513      units: The depth of the query mechanism.
1514      memory: The memory to query; usually the output of an RNN encoder.  This
1515        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
1517        in memory.  If provided, the memory tensor rows are masked with zeros
1518        for values past the respective sequence lengths.
1519      normalize: Python boolean. Whether to normalize the energy term.
1520      sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring
1521        for `_monotonic_probability_fn` for more information.
1522      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
1523      score_bias_init: Initial value for score bias scalar. It's recommended to
1524        initialize this to a negative value when the length of the memory is
1525        large.
1526      mode: How to compute the attention distribution. Must be one of
1527        'recursive', 'parallel', or 'hard'. See the docstring for
1528        `tf.contrib.seq2seq.monotonic_attention` for more information.
1529      kernel_initializer: (optional), the name of the initializer for the
1530        attention kernel.
1531      dtype: The data type for the query and memory layers of the attention
1532        mechanism.
1533      name: Name to use when creating ops.
1534      **kwargs: Dictionary that contains other common arguments for layer
1535        creation.
1536    """
1537    # Set up the monotonic probability fn with supplied parameters
1538    if dtype is None:
1539      dtype = dtypes.float32
1540    wrapped_probability_fn = functools.partial(
1541        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
1542        seed=sigmoid_noise_seed)
1543    query_layer = kwargs.pop("query_layer", None)
1544    if not query_layer:
1545      query_layer = layers.Dense(
1546          units, name="query_layer", use_bias=False, dtype=dtype)
1547    memory_layer = kwargs.pop("memory_layer", None)
1548    if not memory_layer:
1549      memory_layer = layers.Dense(
1550          units, name="memory_layer", use_bias=False, dtype=dtype)
1551    self.units = units
1552    self.normalize = normalize
1553    self.sigmoid_noise = sigmoid_noise
1554    self.sigmoid_noise_seed = sigmoid_noise_seed
1555    self.score_bias_init = score_bias_init
1556    self.mode = mode
1557    self.kernel_initializer = initializers.get(kernel_initializer)
1558    self.attention_v = None
1559    self.attention_score_bias = None
1560    self.attention_g = None
1561    self.attention_b = None
1562    super(BahdanauMonotonicAttentionV2, self).__init__(
1563        memory=memory,
1564        memory_sequence_length=memory_sequence_length,
1565        query_layer=query_layer,
1566        memory_layer=memory_layer,
1567        probability_fn=wrapped_probability_fn,
1568        name=name,
1569        dtype=dtype,
1570        **kwargs)
1571
1572  def build(self, input_shape):
1573    super(BahdanauMonotonicAttentionV2, self).build(input_shape)
1574    if self.attention_v is None:
1575      self.attention_v = self.add_weight(
1576          "attention_v", [self.units], dtype=self.dtype,
1577          initializer=self.kernel_initializer)
1578    if self.attention_score_bias is None:
1579      self.attention_score_bias = self.add_weight(
1580          "attention_score_bias", shape=(), dtype=self.dtype,
1581          initializer=init_ops.constant_initializer(
1582              self.score_bias_init, dtype=self.dtype))
1583    if self.normalize and self.attention_g is None and self.attention_b is None:
1584      self.attention_g = self.add_weight(
1585          "attention_g", dtype=self.dtype,
1586          initializer=init_ops.constant_initializer(
1587              math.sqrt((1. / self.units))),
1588          shape=())
1589      self.attention_b = self.add_weight(
1590          "attention_b", [self.units], dtype=self.dtype,
1591          initializer=init_ops.zeros_initializer())
1592    self.built = True
1593
1594  def _calculate_attention(self, query, state):
1595    """Score the query based on the keys and values.
1596
1597    Args:
1598      query: Tensor of dtype matching `self.values` and shape
1599        `[batch_size, query_depth]`.
1600      state: Tensor of dtype matching `self.values` and shape
1601        `[batch_size, alignments_size]`
1602        (`alignments_size` is memory's `max_time`).
1603
1604    Returns:
1605      alignments: Tensor of dtype matching `self.values` and shape
1606        `[batch_size, alignments_size]` (`alignments_size` is memory's
1607        `max_time`).
1608    """
1609    processed_query = self.query_layer(query) if self.query_layer else query
1610    score = _bahdanau_score(processed_query, self.keys, self.attention_v,
1611                            attention_g=self.attention_g,
1612                            attention_b=self.attention_b)
1613    score += self.attention_score_bias
1614    alignments = self.probability_fn(score, state)
1615    next_state = alignments
1616    return alignments, next_state
1617
1618  def get_config(self):
1619    config = {
1620        "units": self.units,
1621        "normalize": self.normalize,
1622        "sigmoid_noise": self.sigmoid_noise,
1623        "sigmoid_noise_seed": self.sigmoid_noise_seed,
1624        "score_bias_init": self.score_bias_init,
1625        "mode": self.mode,
1626        "kernel_initializer": initializers.serialize(self.kernel_initializer),
1627    }
1628    base_config = super(BahdanauMonotonicAttentionV2, self).get_config()
1629    return dict(list(base_config.items()) + list(config.items()))
1630
1631  @classmethod
1632  def from_config(cls, config, custom_objects=None):
1633    config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config(
1634        config, custom_objects=custom_objects)
1635    return cls(**config)
1636
1637
1638class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
1639  """Monotonic attention mechanism with Luong-style energy function.
1640
1641  This type of attention enforces a monotonic constraint on the attention
  distributions; that is, once the model attends to a given point in the memory
  it can't attend to any prior points at subsequent output timesteps.  It
1644  achieves this by using the _monotonic_probability_fn instead of softmax to
1645  construct its attention distributions.  Otherwise, it is equivalent to
1646  LuongAttention.  This approach is proposed in
1647
1648  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
1649  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
1650  ICML 2017.  https://arxiv.org/abs/1704.00784
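
  Example (a sketch; `encoder_outputs` and `source_lengths` are assumed to be
  defined elsewhere):

  ```python
  attention_mechanism = tf.contrib.seq2seq.LuongMonotonicAttention(
      num_units=128,
      memory=encoder_outputs,
      memory_sequence_length=source_lengths,
      mode="hard")  # "hard" yields one-hot attention; see monotonic_attention
  ```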
1651  """
1652
1653  def __init__(self,
1654               num_units,
1655               memory,
1656               memory_sequence_length=None,
1657               scale=False,
1658               score_mask_value=None,
1659               sigmoid_noise=0.,
1660               sigmoid_noise_seed=None,
1661               score_bias_init=0.,
1662               mode="parallel",
1663               dtype=None,
1664               name="LuongMonotonicAttention"):
1665    """Construct the Attention mechanism.
1666
1667    Args:
1668      num_units: The depth of the query mechanism.
1669      memory: The memory to query; usually the output of an RNN encoder.  This
1670        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
1672        in memory.  If provided, the memory tensor rows are masked with zeros
1673        for values past the respective sequence lengths.
1674      scale: Python boolean.  Whether to scale the energy term.
      score_mask_value: (optional) The mask value for score before passing into
1676        `probability_fn`. The default is -inf. Only used if
1677        `memory_sequence_length` is not None.
1678      sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
1679        for `_monotonic_probability_fn` for more information.
1680      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
1681      score_bias_init: Initial value for score bias scalar.  It's recommended to
1682        initialize this to a negative value when the length of the memory is
1683        large.
1684      mode: How to compute the attention distribution.  Must be one of
1685        'recursive', 'parallel', or 'hard'.  See the docstring for
1686        `tf.contrib.seq2seq.monotonic_attention` for more information.
1687      dtype: The data type for the query and memory layers of the attention
1688        mechanism.
1689      name: Name to use when creating ops.
1690    """
1691    # Set up the monotonic probability fn with supplied parameters
1692    if dtype is None:
1693      dtype = dtypes.float32
1694    wrapped_probability_fn = functools.partial(
1695        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
1696        seed=sigmoid_noise_seed)
1697    super(LuongMonotonicAttention, self).__init__(
1698        query_layer=None,
1699        memory_layer=layers_core.Dense(
1700            num_units, name="memory_layer", use_bias=False, dtype=dtype),
1701        memory=memory,
1702        probability_fn=wrapped_probability_fn,
1703        memory_sequence_length=memory_sequence_length,
1704        score_mask_value=score_mask_value,
1705        name=name)
1706    self._num_units = num_units
1707    self._scale = scale
1708    self._score_bias_init = score_bias_init
1709    self._name = name
1710
1711  def __call__(self, query, state):
1712    """Score the query based on the keys and values.
1713
1714    Args:
1715      query: Tensor of dtype matching `self.values` and shape
1716        `[batch_size, query_depth]`.
1717      state: Tensor of dtype matching `self.values` and shape
1718        `[batch_size, alignments_size]`
1719        (`alignments_size` is memory's `max_time`).
1720
1721    Returns:
1722      alignments: Tensor of dtype matching `self.values` and shape
1723        `[batch_size, alignments_size]` (`alignments_size` is memory's
1724        `max_time`).
1725    """
1726    with variable_scope.variable_scope(None, "luong_monotonic_attention",
1727                                       [query]):
1728      attention_g = None
1729      if self._scale:
1730        attention_g = variable_scope.get_variable(
1731            "attention_g", dtype=query.dtype,
1732            initializer=init_ops.ones_initializer, shape=())
1733      score = _luong_score(query, self._keys, attention_g)
1734      score_bias = variable_scope.get_variable(
1735          "attention_score_bias", dtype=query.dtype,
1736          initializer=self._score_bias_init)
1737      score += score_bias
1738    alignments = self._probability_fn(score, state)
1739    next_state = alignments
1740    return alignments, next_state
1741
1742
1743class LuongMonotonicAttentionV2(_BaseMonotonicAttentionMechanismV2):
1744  """Monotonic attention mechanism with Luong-style energy function.
1745
1746  This type of attention enforces a monotonic constraint on the attention
  distributions; that is, once the model attends to a given point in the memory
  it can't attend to any prior points at subsequent output timesteps.  It
1749  achieves this by using the _monotonic_probability_fn instead of softmax to
1750  construct its attention distributions.  Otherwise, it is equivalent to
1751  LuongAttention.  This approach is proposed in
1752
1753  [Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
1754  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
1755  ICML 2017.](https://arxiv.org/abs/1704.00784)
1756  """
1757
1758  def __init__(self,
1759               units,
1760               memory,
1761               memory_sequence_length=None,
1762               scale=False,
1763               sigmoid_noise=0.,
1764               sigmoid_noise_seed=None,
1765               score_bias_init=0.,
1766               mode="parallel",
1767               dtype=None,
1768               name="LuongMonotonicAttention",
1769               **kwargs):
1770    """Construct the Attention mechanism.
1771
1772    Args:
1773      units: The depth of the query mechanism.
1774      memory: The memory to query; usually the output of an RNN encoder.  This
1775        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
1777        in memory.  If provided, the memory tensor rows are masked with zeros
1778        for values past the respective sequence lengths.
1779      scale: Python boolean.  Whether to scale the energy term.
1780      sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
1781        for `_monotonic_probability_fn` for more information.
1782      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
1783      score_bias_init: Initial value for score bias scalar.  It's recommended to
1784        initialize this to a negative value when the length of the memory is
1785        large.
1786      mode: How to compute the attention distribution.  Must be one of
1787        'recursive', 'parallel', or 'hard'.  See the docstring for
1788        `tf.contrib.seq2seq.monotonic_attention` for more information.
1789      dtype: The data type for the query and memory layers of the attention
1790        mechanism.
1791      name: Name to use when creating ops.
1792      **kwargs: Dictionary that contains other common arguments for layer
1793        creation.
1794    """
1795    # Set up the monotonic probability fn with supplied parameters
1796    if dtype is None:
1797      dtype = dtypes.float32
1798    wrapped_probability_fn = functools.partial(
1799        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
1800        seed=sigmoid_noise_seed)
1801    memory_layer = kwargs.pop("memory_layer", None)
1802    if not memory_layer:
1803      memory_layer = layers.Dense(
1804          units, name="memory_layer", use_bias=False, dtype=dtype)
1805    self.units = units
1806    self.scale = scale
1807    self.sigmoid_noise = sigmoid_noise
1808    self.sigmoid_noise_seed = sigmoid_noise_seed
1809    self.score_bias_init = score_bias_init
1810    self.mode = mode
1811    self.attention_g = None
1812    self.attention_score_bias = None
1813    super(LuongMonotonicAttentionV2, self).__init__(
1814        memory=memory,
1815        memory_sequence_length=memory_sequence_length,
1816        query_layer=None,
1817        memory_layer=memory_layer,
1818        probability_fn=wrapped_probability_fn,
1819        name=name,
1820        dtype=dtype,
1821        **kwargs)
1822
1823  def build(self, input_shape):
1824    super(LuongMonotonicAttentionV2, self).build(input_shape)
1825    if self.scale and self.attention_g is None:
1826      self.attention_g = self.add_weight(
1827          "attention_g", initializer=init_ops.ones_initializer, shape=())
1828    if self.attention_score_bias is None:
1829      self.attention_score_bias = self.add_weight(
1830          "attention_score_bias", shape=(),
1831          initializer=init_ops.constant_initializer(
1832              self.score_bias_init, dtype=self.dtype))
1833    self.built = True
1834
1835  def _calculate_attention(self, query, state):
1836    """Score the query based on the keys and values.
1837
1838    Args:
1839      query: Tensor of dtype matching `self.values` and shape
1840        `[batch_size, query_depth]`.
1841      state: Tensor of dtype matching `self.values` and shape
1842        `[batch_size, alignments_size]`
1843        (`alignments_size` is memory's `max_time`).
1844
1845    Returns:
1846      alignments: Tensor of dtype matching `self.values` and shape
1847        `[batch_size, alignments_size]` (`alignments_size` is memory's
1848        `max_time`).
1849      next_state: Same as alignments
1850    """
1851    score = _luong_score(query, self.keys, self.attention_g)
1852    score += self.attention_score_bias
1853    alignments = self.probability_fn(score, state)
1854    next_state = alignments
1855    return alignments, next_state
1856
1857  def get_config(self):
1858    config = {
1859        "units": self.units,
1860        "scale": self.scale,
1861        "sigmoid_noise": self.sigmoid_noise,
1862        "sigmoid_noise_seed": self.sigmoid_noise_seed,
1863        "score_bias_init": self.score_bias_init,
1864        "mode": self.mode,
1865    }
1866    base_config = super(LuongMonotonicAttentionV2, self).get_config()
1867    return dict(list(base_config.items()) + list(config.items()))
1868
1869  @classmethod
1870  def from_config(cls, config, custom_objects=None):
1871    config = _BaseAttentionMechanismV2.deserialize_inner_layer_from_config(
1872        config, custom_objects=custom_objects)
1873    return cls(**config)
1874
1875
1876class AttentionWrapperState(
1877    collections.namedtuple("AttentionWrapperState",
1878                           ("cell_state", "attention", "time", "alignments",
1879                            "alignment_history", "attention_state"))):
  """`namedtuple` storing the state of an `AttentionWrapper`.
1881
1882  Contains:
1883
1884    - `cell_state`: The state of the wrapped `RNNCell` at the previous time
1885      step.
1886    - `attention`: The attention emitted at the previous time step.
1887    - `time`: int32 scalar containing the current time step.
1888    - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
1889       emitted at the previous time step for each attention mechanism.
1890    - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
1891       containing alignment matrices from all time steps for each attention
1892       mechanism. Call `stack()` on each to convert to a `Tensor`.
1893    - `attention_state`: A single or tuple of nested objects
1894       containing attention mechanism state for each attention mechanism.
1895       The objects may contain Tensors or TensorArrays.
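
  For example, when the wrapper was constructed with `alignment_history=True`,
  the stored history can be stacked into a single tensor (a sketch;
  `final_state` is assumed to be the last `AttentionWrapperState` produced by
  decoding):

  ```python
  alignment_history = final_state.alignment_history.stack()
  # shape: [decoder_time, batch_size, memory_time]
  ```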
1896  """
1897
1898  def clone(self, **kwargs):
1899    """Clone this object, overriding components provided by kwargs.
1900
1901    The new state fields' shape must match original state fields' shape. This
1902    will be validated, and original fields' shape will be propagated to new
1903    fields.
1904
1905    Example:
1906
1907    ```python
1908    initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
1909    initial_state = initial_state.clone(cell_state=encoder_state)
1910    ```
1911
1912    Args:
1913      **kwargs: Any properties of the state object to replace in the returned
1914        `AttentionWrapperState`.
1915
1916    Returns:
1917      A new `AttentionWrapperState` whose properties are the same as
1918      this one, except any overridden properties as provided in `kwargs`.
1919    """
1920    def with_same_shape(old, new):
1921      """Check and set new tensor's shape."""
1922      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
1923        if not context.executing_eagerly():
1924          return tensor_util.with_same_shape(old, new)
1925        else:
1926          if old.shape.as_list() != new.shape.as_list():
1927            raise ValueError("The shape of the AttentionWrapperState is "
                             "expected to be the same as the one to clone. "
1929                             "self.shape: %s, input.shape: %s" %
1930                             (old.shape, new.shape))
1931          return new
1932      return new
1933
1934    return nest.map_structure(
1935        with_same_shape,
1936        self,
1937        super(AttentionWrapperState, self)._replace(**kwargs))
1938
1939
1940def _prepare_memory(memory, memory_sequence_length=None, memory_mask=None,
1941                    check_inner_dims_defined=True):
1942  """Convert to tensor and possibly mask `memory`.
1943
1944  Args:
1945    memory: `Tensor`, shaped `[batch_size, max_time, ...]`.
1946    memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`.
    memory_mask: `boolean` tensor with shape `[batch_size, max_time]`. Memory
      entries are masked out where the corresponding mask value is `False`.
1949    check_inner_dims_defined: Python boolean.  If `True`, the `memory`
1950      argument's shape is checked to ensure all but the two outermost
1951      dimensions are fully defined.
1952
1953  Returns:
1954    A (possibly masked), checked, new `memory`.
1955
1956  Raises:
1957    ValueError: If `check_inner_dims_defined` is `True` and not
1958      `memory.shape[2:].is_fully_defined()`.
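
  Example (a small illustration with hypothetical values):

  ```python
  memory = tf.ones([1, 3, 2])
  masked = _prepare_memory(memory, memory_sequence_length=[2])
  # masked[0] == [[1., 1.], [1., 1.], [0., 0.]]
  ```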
1959  """
1960  memory = nest.map_structure(
1961      lambda m: ops.convert_to_tensor(m, name="memory"), memory)
1962  if memory_sequence_length is not None and memory_mask is not None:
1963    raise ValueError("memory_sequence_length and memory_mask can't be provided "
                     "at the same time.")
1965  if memory_sequence_length is not None:
1966    memory_sequence_length = ops.convert_to_tensor(
1967        memory_sequence_length, name="memory_sequence_length")
1968  if check_inner_dims_defined:
1969    def _check_dims(m):
1970      if not m.get_shape()[2:].is_fully_defined():
1971        raise ValueError("Expected memory %s to have fully defined inner dims, "
1972                         "but saw shape: %s" % (m.name, m.get_shape()))
1973    nest.map_structure(_check_dims, memory)
1974  if memory_sequence_length is None and memory_mask is None:
1975    return memory
1976  elif memory_sequence_length is not None:
1977    seq_len_mask = array_ops.sequence_mask(
1978        memory_sequence_length,
1979        maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
1980        dtype=nest.flatten(memory)[0].dtype)
1981  else:
1982    # For memory_mask is not None
1983    seq_len_mask = math_ops.cast(
1984        memory_mask, dtype=nest.flatten(memory)[0].dtype)
1985  def _maybe_mask(m, seq_len_mask):
1986    """Mask the memory based on the memory mask."""
1987    rank = m.get_shape().ndims
1988    rank = rank if rank is not None else array_ops.rank(m)
1989    extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
1990    seq_len_mask = array_ops.reshape(
1991        seq_len_mask,
1992        array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
1993    return m * seq_len_mask
1994
1995  return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
1996
1997
1998def _maybe_mask_score(score, memory_sequence_length=None, memory_mask=None,
1999                      score_mask_value=None):
2000  """Mask the attention score based on the masks."""
2001  if memory_sequence_length is None and memory_mask is None:
2002    return score
2003  if memory_sequence_length is not None and memory_mask is not None:
2004    raise ValueError("memory_sequence_length and memory_mask can't be provided "
                     "at the same time.")
2006  if memory_sequence_length is not None:
    message = ("All values in memory_sequence_length must be greater than "
               "zero.")
2008    with ops.control_dependencies(
2009        [check_ops.assert_positive(memory_sequence_length, message=message)]):
2010      memory_mask = array_ops.sequence_mask(
2011          memory_sequence_length, maxlen=array_ops.shape(score)[1])
2012  score_mask_values = score_mask_value * array_ops.ones_like(score)
2013  return array_ops.where(memory_mask, score, score_mask_values)
2014
2015
2016def hardmax(logits, name=None):
2017  """Returns batched one-hot vectors.
2018
2019  The depth index containing the `1` is that of the maximum logit value.
2020
2021  Args:
2022    logits: A batch tensor of logit values.
2023    name: Name to use when creating ops.
2024  Returns:
2025    A batched one-hot tensor.
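
  Example (illustrative values):

  ```python
  logits = tf.constant([[0.1, 2.0, 0.3],
                        [4.0, 0.5, 0.6]])
  tf.contrib.seq2seq.hardmax(logits)  # => [[0., 1., 0.], [1., 0., 0.]]
  ```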
2026  """
2027  with ops.name_scope(name, "Hardmax", [logits]):
2028    logits = ops.convert_to_tensor(logits, name="logits")
2029    if tensor_shape.dimension_value(logits.get_shape()[-1]) is not None:
2030      depth = tensor_shape.dimension_value(logits.get_shape()[-1])
2031    else:
2032      depth = array_ops.shape(logits)[-1]
2033    return array_ops.one_hot(
2034        math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
2035
2036
2037def _compute_attention(attention_mechanism, cell_output, attention_state,
2038                       attention_layer):
2039  """Computes the attention and alignments for a given attention_mechanism."""
2040  if isinstance(attention_mechanism, _BaseAttentionMechanismV2):
2041    alignments, next_attention_state = attention_mechanism(
2042        [cell_output, attention_state])
2043  else:
    # For other classes, assume they follow _BaseAttentionMechanism, which
    # takes query and state as separate parameters.
2046    alignments, next_attention_state = attention_mechanism(
2047        cell_output, state=attention_state)
2048
2049  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
2050  expanded_alignments = array_ops.expand_dims(alignments, 1)
2051  # Context is the inner product of alignments and values along the
2052  # memory time dimension.
2053  # alignments shape is
2054  #   [batch_size, 1, memory_time]
2055  # attention_mechanism.values shape is
2056  #   [batch_size, memory_time, memory_size]
2057  # the batched matmul is over memory_time, so the output shape is
2058  #   [batch_size, 1, memory_size].
2059  # we then squeeze out the singleton dim.
2060  context_ = math_ops.matmul(expanded_alignments, attention_mechanism.values)
2061  context_ = array_ops.squeeze(context_, [1])
2062
2063  if attention_layer is not None:
2064    attention = attention_layer(array_ops.concat([cell_output, context_], 1))
2065  else:
2066    attention = context_
2067
2068  return attention, alignments, next_attention_state
2069
2070
2071class AttentionWrapper(rnn_cell_impl.RNNCell):
2072  """Wraps another `RNNCell` with attention.
2073  """
2074
2075  def __init__(self,
2076               cell,
2077               attention_mechanism,
2078               attention_layer_size=None,
2079               alignment_history=False,
2080               cell_input_fn=None,
2081               output_attention=True,
2082               initial_cell_state=None,
2083               name=None,
2084               attention_layer=None,
2085               attention_fn=None):
2086    """Construct the `AttentionWrapper`.
2087
2088    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
2089    `AttentionWrapper`, then you must ensure that:
2090
2091    - The encoder output has been tiled to `beam_width` via
2092      `tf.contrib.seq2seq.tile_batch` (NOT `tf.tile`).
2093    - The `batch_size` argument passed to the `zero_state` method of this
2094      wrapper is equal to `true_batch_size * beam_width`.
2095    - The initial state created with `zero_state` above contains a
2096      `cell_state` value containing properly tiled final state from the
2097      encoder.
2098
2099    An example:
2100
2101    ```
2102    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
2103        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
2105        encoder_final_state, multiplier=beam_width)
2106    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
2107        sequence_length, multiplier=beam_width)
2108    attention_mechanism = MyFavoriteAttentionMechanism(
2109        num_units=attention_depth,
        memory=tiled_encoder_outputs,
2111        memory_sequence_length=tiled_sequence_length)
2112    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
2113    decoder_initial_state = attention_cell.zero_state(
2114        dtype, batch_size=true_batch_size * beam_width)
2115    decoder_initial_state = decoder_initial_state.clone(
2116        cell_state=tiled_encoder_final_state)
2117    ```
2118
2119    Args:
2120      cell: An instance of `RNNCell`.
2121      attention_mechanism: A list of `AttentionMechanism` instances or a single
2122        instance.
2123      attention_layer_size: A list of Python integers or a single Python
2124        integer, the depth of the attention (output) layer(s). If None
2125        (default), use the context as attention at each time step. Otherwise,
2126        feed the context and cell output into the attention layer to generate
2127        attention at each time step. If attention_mechanism is a list,
2128        attention_layer_size must be a list of the same length. If
2129        attention_layer is set, this must be None. If attention_fn is set,
        it must be guaranteed that the outputs of attention_fn also meet the
2131        above requirements.
2132      alignment_history: Python boolean, whether to store alignment history
2133        from all time steps in the final output state (currently stored as a
2134        time major `TensorArray` on which you must call `stack()`).
2135      cell_input_fn: (optional) A `callable`.  The default is:
2136        `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
2137      output_attention: Python bool.  If `True` (default), the output at each
2138        time step is the attention value.  This is the behavior of Luong-style
2139        attention mechanisms.  If `False`, the output at each time step is
        the output of `cell`.  This is the behavior of Bahdanau-style
2141        attention mechanisms.  In both cases, the `attention` tensor is
2142        propagated to the next time step via the state and is used there.
2143        This flag only controls whether the attention mechanism is propagated
2144        up to the next cell in an RNN stack or to the top RNN output.
2145      initial_cell_state: The initial state value to use for the cell when
2146        the user calls `zero_state()`.  Note that if this value is provided
2147        now, and the user uses a `batch_size` argument of `zero_state` which
2148        does not match the batch size of `initial_cell_state`, proper
2149        behavior is not guaranteed.
2150      name: Name to use when creating ops.
2151      attention_layer: A list of `tf.layers.Layer` instances or a
2152        single `tf.layers.Layer` instance taking the context and cell output as
2153        inputs to generate attention at each time step. If None (default), use
2154        the context as attention at each time step. If attention_mechanism is a
2155        list, attention_layer must be a list of the same length. If
2156        attention_layers_size is set, this must be None.
      attention_fn: An optional callable that allows users to provide their
        own customized attention function, which takes as inputs
2159        (attention_mechanism, cell_output, attention_state, attention_layer) and
2160        outputs (attention, alignments, next_attention_state). If provided,
2161        the attention_layer_size should be the size of the outputs of
2162        attention_fn.
2163
2164    Raises:
2165      TypeError: `attention_layer_size` is not None and (`attention_mechanism`
2166        is a list but `attention_layer_size` is not; or vice versa).
2167      ValueError: if `attention_layer_size` is not None, `attention_mechanism`
2168        is a list, and its length does not match that of `attention_layer_size`;
2169        if `attention_layer_size` and `attention_layer` are set simultaneously.
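
    A basic (non-beam-search) usage sketch, reusing the names from the example
    above:

    ```
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        num_units=attention_depth,
        memory=encoder_outputs,
        memory_sequence_length=sequence_length)
    attention_cell = AttentionWrapper(
        cell,
        attention_mechanism,
        attention_layer_size=attention_depth,
        alignment_history=True)
    ```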
2170    """
2171    super(AttentionWrapper, self).__init__(name=name)
2172    rnn_cell_impl.assert_like_rnncell("cell", cell)
2173    if isinstance(attention_mechanism, (list, tuple)):
2174      self._is_multi = True
2175      attention_mechanisms = attention_mechanism
2176      for attention_mechanism in attention_mechanisms:
2177        if not isinstance(attention_mechanism, AttentionMechanism):
2178          raise TypeError(
2179              "attention_mechanism must contain only instances of "
2180              "AttentionMechanism, saw type: %s"
2181              % type(attention_mechanism).__name__)
2182    else:
2183      self._is_multi = False
2184      if not isinstance(attention_mechanism, AttentionMechanism):
2185        raise TypeError(
2186            "attention_mechanism must be an AttentionMechanism or list of "
2187            "multiple AttentionMechanism instances, saw type: %s"
2188            % type(attention_mechanism).__name__)
2189      attention_mechanisms = (attention_mechanism,)
2190
2191    if cell_input_fn is None:
2192      cell_input_fn = (
2193          lambda inputs, attention: array_ops.concat([inputs, attention], -1))
2194    else:
2195      if not callable(cell_input_fn):
2196        raise TypeError(
2197            "cell_input_fn must be callable, saw type: %s"
2198            % type(cell_input_fn).__name__)
2199
2200    if attention_layer_size is not None and attention_layer is not None:
2201      raise ValueError("Only one of attention_layer_size and attention_layer "
2202                       "should be set")
2203
2204    if attention_layer_size is not None:
2205      attention_layer_sizes = tuple(
2206          attention_layer_size
2207          if isinstance(attention_layer_size, (list, tuple))
2208          else (attention_layer_size,))
2209      if len(attention_layer_sizes) != len(attention_mechanisms):
2210        raise ValueError(
2211            "If provided, attention_layer_size must contain exactly one "
2212            "integer per attention_mechanism, saw: %d vs %d"
2213            % (len(attention_layer_sizes), len(attention_mechanisms)))
2214      self._attention_layers = tuple(
2215          layers_core.Dense(
2216              attention_layer_size,
2217              name="attention_layer",
2218              use_bias=False,
2219              dtype=attention_mechanisms[i].dtype)
2220          for i, attention_layer_size in enumerate(attention_layer_sizes))
2221      self._attention_layer_size = sum(attention_layer_sizes)
2222    elif attention_layer is not None:
2223      self._attention_layers = tuple(
2224          attention_layer
2225          if isinstance(attention_layer, (list, tuple))
2226          else (attention_layer,))
2227      if len(self._attention_layers) != len(attention_mechanisms):
2228        raise ValueError(
2229            "If provided, attention_layer must contain exactly one "
2230            "layer per attention_mechanism, saw: %d vs %d"
2231            % (len(self._attention_layers), len(attention_mechanisms)))
2232      self._attention_layer_size = sum(
2233          tensor_shape.dimension_value(layer.compute_output_shape(
2234              [None,
2235               cell.output_size + tensor_shape.dimension_value(
2236                   mechanism.values.shape[-1])])[-1])
2237          for layer, mechanism in zip(
2238              self._attention_layers, attention_mechanisms))
2239    else:
2240      self._attention_layers = None
2241      self._attention_layer_size = sum(
2242          tensor_shape.dimension_value(attention_mechanism.values.shape[-1])
2243          for attention_mechanism in attention_mechanisms)
2244
2245    if attention_fn is None:
2246      attention_fn = _compute_attention
2247    self._attention_fn = attention_fn
2248
2249    self._cell = cell
2250    self._attention_mechanisms = attention_mechanisms
2251    self._cell_input_fn = cell_input_fn
2252    self._output_attention = output_attention
2253    self._alignment_history = alignment_history
2254    with ops.name_scope(name, "AttentionWrapperInit"):
2255      if initial_cell_state is None:
2256        self._initial_cell_state = None
2257      else:
2258        final_state_tensor = nest.flatten(initial_cell_state)[-1]
2259        state_batch_size = (
2260            tensor_shape.dimension_value(final_state_tensor.shape[0])
2261            or array_ops.shape(final_state_tensor)[0])
2262        error_message = (
2263            "When constructing AttentionWrapper %s: " % self._base_name +
2264            "Non-matching batch sizes between the memory "
2265            "(encoder output) and initial_cell_state.  Are you using "
2266            "the BeamSearchDecoder?  You may need to tile your initial state "
2267            "via the tf.contrib.seq2seq.tile_batch function with argument "
            "multiplier=beam_width.")
2269        with ops.control_dependencies(
2270            self._batch_size_checks(state_batch_size, error_message)):
2271          self._initial_cell_state = nest.map_structure(
2272              lambda s: array_ops.identity(s, name="check_initial_cell_state"),
2273              initial_cell_state)
2274
2275  def _batch_size_checks(self, batch_size, error_message):
2276    return [check_ops.assert_equal(batch_size,
2277                                   attention_mechanism.batch_size,
2278                                   message=error_message)
2279            for attention_mechanism in self._attention_mechanisms]
2280
2281  def _item_or_tuple(self, seq):
2282    """Returns `seq` as tuple or the singular element.
2283
    Whether a tuple or a single element is returned is determined by how the
    AttentionMechanism(s) were passed to the constructor.
2286
2287    Args:
2288      seq: A non-empty sequence of items or generator.
2289
2290    Returns:
       The values in the sequence as a tuple if the AttentionMechanism(s) were
       passed to the constructor as a sequence, otherwise the single element.
2293    """
2294    t = tuple(seq)
2295    if self._is_multi:
2296      return t
2297    else:
2298      return t[0]
2299
2300  @property
2301  def output_size(self):
2302    if self._output_attention:
2303      return self._attention_layer_size
2304    else:
2305      return self._cell.output_size
2306
2307  @property
2308  def state_size(self):
2309    """The `state_size` property of `AttentionWrapper`.
2310
2311    Returns:
2312      An `AttentionWrapperState` tuple containing shapes used by this object.
2313    """
2314    return AttentionWrapperState(
2315        cell_state=self._cell.state_size,
2316        time=tensor_shape.TensorShape([]),
2317        attention=self._attention_layer_size,
2318        alignments=self._item_or_tuple(
2319            a.alignments_size for a in self._attention_mechanisms),
2320        attention_state=self._item_or_tuple(
2321            a.state_size for a in self._attention_mechanisms),
2322        alignment_history=self._item_or_tuple(
2323            a.alignments_size if self._alignment_history else ()
2324            for a in self._attention_mechanisms))  # sometimes a TensorArray
2325
2326  def zero_state(self, batch_size, dtype):
2327    """Return an initial (zero) state tuple for this `AttentionWrapper`.
2328
2329    **NOTE** Please see the initializer documentation for details of how
2330    to call `zero_state` if using an `AttentionWrapper` with a
2331    `BeamSearchDecoder`.
2332
2333    Args:
2334      batch_size: `0D` integer tensor: the batch size.
2335      dtype: The internal state data type.
2336
2337    Returns:
2338      An `AttentionWrapperState` tuple containing zeroed out tensors and,
2339      possibly, empty `TensorArray` objects.
2340
2341    Raises:
2342      ValueError: (or, possibly at runtime, InvalidArgument), if
        `batch_size` does not match the batch size of the encoder output
        passed to the wrapper object at initialization time.
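
    A typical call, cloning in the (possibly tiled) final encoder state
    (a sketch; `attention_cell` is this wrapper and `encoder_final_state` is
    assumed to be the encoder's final state):

    ```
    decoder_initial_state = attention_cell.zero_state(
        batch_size, dtype).clone(cell_state=encoder_final_state)
    ```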
2345    """
2346    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
2347      if self._initial_cell_state is not None:
2348        cell_state = self._initial_cell_state
2349      else:
2350        cell_state = self._cell.get_initial_state(batch_size=batch_size,
2351                                                  dtype=dtype)
2352      error_message = (
2353          "When calling zero_state of AttentionWrapper %s: " % self._base_name +
2354          "Non-matching batch sizes between the memory "
2355          "(encoder output) and the requested batch size.  Are you using "
2356          "the BeamSearchDecoder?  If so, make sure your encoder output has "
2357          "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
2358          "the batch_size= argument passed to zero_state is "
2359          "batch_size * beam_width.")
2360      with ops.control_dependencies(
2361          self._batch_size_checks(batch_size, error_message)):
2362        cell_state = nest.map_structure(
2363            lambda s: array_ops.identity(s, name="checked_cell_state"),
2364            cell_state)
2365      initial_alignments = [
2366          attention_mechanism.initial_alignments(batch_size, dtype)
2367          for attention_mechanism in self._attention_mechanisms]
2368      return AttentionWrapperState(
2369          cell_state=cell_state,
2370          time=array_ops.zeros([], dtype=dtypes.int32),
2371          attention=_zero_state_tensors(self._attention_layer_size, batch_size,
2372                                        dtype),
2373          alignments=self._item_or_tuple(initial_alignments),
2374          attention_state=self._item_or_tuple(
2375              attention_mechanism.initial_state(batch_size, dtype)
2376              for attention_mechanism in self._attention_mechanisms),
2377          alignment_history=self._item_or_tuple(
2378              tensor_array_ops.TensorArray(
2379                  dtype,
2380                  size=0,
2381                  dynamic_size=True,
2382                  element_shape=alignment.shape)
2383              if self._alignment_history else ()
2384              for alignment in initial_alignments))
2385
2386  def call(self, inputs, state):
2387    """Perform a step of attention-wrapped RNN.
2388
2389    - Step 1: Mix the `inputs` and previous step's `attention` output via
2390      `cell_input_fn`.
2391    - Step 2: Call the wrapped `cell` with this input and its previous state.
2392    - Step 3: Score the cell's output with `attention_mechanism`.
2393    - Step 4: Calculate the alignments by passing the score through the
2394      `normalizer`.
2395    - Step 5: Calculate the context vector as the inner product between the
2396      alignments and the attention_mechanism's values (memory).
2397    - Step 6: Calculate the attention output by concatenating the cell output
2398      and context through the attention layer (a linear layer with
2399      `attention_layer_size` outputs).
2400
2401    Args:
2402      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
2403      state: An instance of `AttentionWrapperState` containing
2404        tensors from the previous time step.
2405
2406    Returns:
2407      A tuple `(attention_or_cell_output, next_state)`, where:
2408
      - `attention_or_cell_output`: the attention value if `output_attention`
        is `True`, otherwise the output of the wrapped `cell`.
2410      - `next_state` is an instance of `AttentionWrapperState`
2411         containing the state calculated at this time step.
2412
2413    Raises:
2414      TypeError: If `state` is not an instance of `AttentionWrapperState`.
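
    This method is normally invoked for every decoding step via
    `tf.nn.dynamic_rnn` or a seq2seq decoder rather than called directly
    (a sketch; `inputs` is a `[batch_size, max_time, depth]` tensor):

    ```
    outputs, final_state = tf.nn.dynamic_rnn(
        attention_cell, inputs, initial_state=decoder_initial_state)
    ```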
2415    """
2416    if not isinstance(state, AttentionWrapperState):
2417      raise TypeError("Expected state to be instance of AttentionWrapperState. "
2418                      "Received type %s instead."  % type(state))
2419
2420    # Step 1: Calculate the true inputs to the cell based on the
2421    # previous attention value.
2422    cell_inputs = self._cell_input_fn(inputs, state.attention)
2423    cell_state = state.cell_state
2424    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
2425
2426    cell_batch_size = (
2427        tensor_shape.dimension_value(cell_output.shape[0]) or
2428        array_ops.shape(cell_output)[0])
2429    error_message = (
2430        "When applying AttentionWrapper %s: " % self.name +
2431        "Non-matching batch sizes between the memory "
2432        "(encoder output) and the query (decoder output).  Are you using "
2433        "the BeamSearchDecoder?  You may need to tile your memory input via "
2434        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiplier=beam_width.")
2436    with ops.control_dependencies(
2437        self._batch_size_checks(cell_batch_size, error_message)):
2438      cell_output = array_ops.identity(
2439          cell_output, name="checked_cell_output")
2440
2441    if self._is_multi:
2442      previous_attention_state = state.attention_state
2443      previous_alignment_history = state.alignment_history
2444    else:
2445      previous_attention_state = [state.attention_state]
2446      previous_alignment_history = [state.alignment_history]
2447
2448    all_alignments = []
2449    all_attentions = []
2450    all_attention_states = []
2451    maybe_all_histories = []
2452    for i, attention_mechanism in enumerate(self._attention_mechanisms):
2453      attention, alignments, next_attention_state = self._attention_fn(
2454          attention_mechanism, cell_output, previous_attention_state[i],
2455          self._attention_layers[i] if self._attention_layers else None)
2456      alignment_history = previous_alignment_history[i].write(
2457          state.time, alignments) if self._alignment_history else ()
2458
2459      all_attention_states.append(next_attention_state)
2460      all_alignments.append(alignments)
2461      all_attentions.append(attention)
2462      maybe_all_histories.append(alignment_history)
2463
2464    attention = array_ops.concat(all_attentions, 1)
2465    next_state = AttentionWrapperState(
2466        time=state.time + 1,
2467        cell_state=next_cell_state,
2468        attention=attention,
2469        attention_state=self._item_or_tuple(all_attention_states),
2470        alignments=self._item_or_tuple(all_alignments),
2471        alignment_history=self._item_or_tuple(maybe_all_histories))
2472
2473    if self._output_attention:
2474      return attention, next_state
2475    else:
2476      return cell_output, next_state
2477