# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention layers that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import keras_export


class BaseDenseAttention(Layer):
  """Base Attention class for Dense networks.

  This class is suitable for Dense or CNN networks, and not for RNN networks.

  Implementations of attention mechanisms should inherit from this class, and
  reuse the `_apply_scores()` method.

  Args:
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.
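
  Here is a minimal usage sketch of the simplest concrete subclass,
  `Attention` (the tensor shapes are illustrative):

  ```python
  query = tf.random.normal(shape=(2, 4, 8))  # [batch_size, Tq, dim]
  value = tf.random.normal(shape=(2, 6, 8))  # [batch_size, Tv, dim]
  output, scores = tf.keras.layers.Attention()(
      [query, value], return_attention_scores=True)
  # output.shape == (2, 4, 8), scores.shape == (2, 4, 6)
  ```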
  """

  def __init__(self, causal=False, dropout=0.0,
               **kwargs):
    super(BaseDenseAttention, self).__init__(**kwargs)
    self.causal = causal
    self.dropout = dropout
    self.supports_masking = True

  def _calculate_scores(self, query, key):
    """Calculates attention scores.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    raise NotImplementedError

  def _apply_scores(self, scores, value, scores_mask=None, training=None):
    """Applies attention scores to the given value tensor.

    To use this method in your attention layer, follow the steps:

    * Use `query` tensor of shape `[batch_size, Tq, dim]` and `key` tensor of
      shape `[batch_size, Tv, dim]` to calculate the attention `scores`.
    * Pass `scores` and `value` tensors to this method. The method applies
      `scores_mask`, calculates `attention_distribution = softmax(scores)`, then
      returns `matmul(attention_distribution, value)`.
    * Apply `query_mask` and return the result.

    Args:
      scores: Scores float tensor of shape `[batch_size, Tq, Tv]`.
      value: Value tensor of shape `[batch_size, Tv, dim]`.
      scores_mask: A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or
        `[batch_size, Tq, Tv]`. If given, scores at positions where
        `scores_mask==False` do not contribute to the result. It must contain
        at least one `True` value in each line along the last dimension.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (no dropout).

    Returns:
      Tensor of shape `[batch_size, Tq, dim]`.
      Attention scores after masking and softmax with shape
        `[batch_size, Tq, Tv]`.
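
    For example, a concrete subclass instance such as `Attention` can apply
    precomputed scores as follows (an illustrative sketch; shapes only):

    ```python
    layer = tf.keras.layers.Attention()
    scores = tf.random.normal(shape=(2, 4, 6))      # [batch_size, Tq, Tv]
    value = tf.random.normal(shape=(2, 6, 8))       # [batch_size, Tv, dim]
    mask = tf.ones(shape=(2, 1, 6), dtype=tf.bool)  # [batch_size, 1, Tv]
    result, weights = layer._apply_scores(
        scores=scores, value=value, scores_mask=mask)
    # result.shape == (2, 4, 8), weights.shape == (2, 4, 6)
    ```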
    """
    if scores_mask is not None:
      padding_mask = math_ops.logical_not(scores_mask)
      # Bias so padding positions do not contribute to attention distribution.
      # Note 65504. is the max float16 value.
      if scores.dtype is dtypes.float16:
        scores -= 65504. * math_ops.cast(padding_mask, dtype=scores.dtype)
      else:
        scores -= 1.e9 * math_ops.cast(padding_mask, dtype=scores.dtype)
    if training is None:
      training = backend.learning_phase()
    weights = nn.softmax(scores)

    def dropped_weights():
      return nn.dropout(weights, rate=self.dropout)

    weights = control_flow_util.smart_cond(training, dropped_weights,
                                           lambda: array_ops.identity(weights))
    return math_ops.matmul(weights, value), weights

  # TODO(b/125916026): Consider exposing a __call__ method with named args.
  def call(self,
           inputs,
           mask=None,
           training=None,
           return_attention_scores=False):
    self._validate_call_args(inputs=inputs, mask=mask)
    q = inputs[0]
    v = inputs[1]
    k = inputs[2] if len(inputs) > 2 else v
    q_mask = mask[0] if mask else None
    v_mask = mask[1] if mask else None
    scores = self._calculate_scores(query=q, key=k)
    if v_mask is not None:
      # Mask of shape [batch_size, 1, Tv].
      v_mask = array_ops.expand_dims(v_mask, axis=-2)
    if self.causal:
      # Creates a lower triangular mask, so position i cannot attend to
      # positions j>i. This prevents the flow of information from the future
      # into the past.
      scores_shape = array_ops.shape(scores)
      # causal_mask_shape = [1, Tq, Tv].
      causal_mask_shape = array_ops.concat(
          [array_ops.ones_like(scores_shape[:-2]), scores_shape[-2:]],
          axis=0)
      causal_mask = _lower_triangular_mask(causal_mask_shape)
    else:
      causal_mask = None
    scores_mask = _merge_masks(v_mask, causal_mask)
    result, attention_scores = self._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask, training=training)
    if q_mask is not None:
      # Mask of shape [batch_size, Tq, 1].
      q_mask = array_ops.expand_dims(q_mask, axis=-1)
      result *= math_ops.cast(q_mask, dtype=result.dtype)
    if return_attention_scores:
      return result, attention_scores
    return result

  def compute_mask(self, inputs, mask=None):
    self._validate_call_args(inputs=inputs, mask=mask)
    if mask:
      q_mask = mask[0]
      if q_mask is None:
        return None
      return ops.convert_to_tensor_v2_with_dispatch(q_mask)
    return None

  def _validate_call_args(self, inputs, mask):
    """Validates arguments of the call method."""
    class_name = self.__class__.__name__
    if not isinstance(inputs, list):
      raise ValueError(
          '{} layer must be called on a list of inputs, namely [query, value] '
          'or [query, value, key].'.format(class_name))
    if len(inputs) < 2 or len(inputs) > 3:
      raise ValueError(
          '{} layer accepts inputs list of length 2 or 3, '
          'namely [query, value] or [query, value, key]. '
          'Given length: {}'.format(class_name, len(inputs)))
    if mask:
      if not isinstance(mask, list):
        raise ValueError(
            '{} layer mask must be a list, '
            'namely [query_mask, value_mask].'.format(class_name))
      if len(mask) < 2 or len(mask) > len(inputs):
        raise ValueError(
            '{} layer mask must be a list of length 2, namely [query_mask, '
            'value_mask]. Given length: {}'.format(class_name, len(mask)))

  def get_config(self):
    config = {
        'causal': self.causal,
        'dropout': self.dropout,
    }
    base_config = super(BaseDenseAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.Attention')
class Attention(BaseDenseAttention):
  """Dot-product attention layer, a.k.a. Luong-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
     product: `scores = tf.matmul(query, key, transpose_b=True)`.
  2. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  3. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.

  Args:
    use_scale: If `True`, will create a scalar variable to scale the attention
      scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `Attention` in a CNN+Attention network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.Attention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
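
  A shorter, minimal sketch using `causal` self-attention and returning the
  attention scores (shapes are illustrative):

  ```python
  x = tf.random.normal(shape=(2, 5, 8))  # Self-attention: query == value.
  output, scores = tf.keras.layers.Attention(causal=True)(
      [x, x], return_attention_scores=True)
  # output.shape == (2, 5, 8), scores.shape == (2, 5, 5); scores[:, i, j] is
  # (numerically) zero for j > i because of the causal mask.
  ```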
  """

  def __init__(self, use_scale=False, **kwargs):
    super(Attention, self).__init__(**kwargs)
    self.use_scale = use_scale

  def build(self, input_shape):
    """Creates scale variable if use_scale==True."""
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',
          shape=(),
          initializer=init_ops.ones_initializer(),
          dtype=self.dtype,
          trainable=True)
    else:
      self.scale = None
    super(Attention, self).build(input_shape)

  def _calculate_scores(self, query, key):
    """Calculates attention scores as a query-key dot product.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.
    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    scores = math_ops.matmul(query, key, transpose_b=True)
    if self.scale is not None:
      scores *= self.scale
    return scores

  def get_config(self):
    config = {'use_scale': self.use_scale}
    base_config = super(Attention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.AdditiveAttention')
class AdditiveAttention(BaseDenseAttention):
  """Additive attention layer, a.k.a. Bahdanau-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Reshape `query` and `key` into shapes `[batch_size, Tq, 1, dim]`
     and `[batch_size, 1, Tv, dim]` respectively.
  2. Calculate scores with shape `[batch_size, Tq, Tv]` as a non-linear
     sum: `scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)`.
  3. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  4. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.

  Args:
    use_scale: If `True`, will create a variable to scale the attention scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `AdditiveAttention` in a CNN+Attention
  network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.AdditiveAttention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
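
  The computation can also be spelled out directly, following the steps above.
  This is an illustrative sketch with `use_scale=False`, so no scale vector is
  applied:

  ```python
  query = tf.random.normal(shape=(2, 4, 8))  # [batch_size, Tq, dim]
  value = tf.random.normal(shape=(2, 6, 8))  # [batch_size, Tv, dim]
  output = tf.keras.layers.AdditiveAttention(use_scale=False)([query, value])

  # Equivalent manual computation (key defaults to value).
  scores = tf.reduce_sum(
      tf.tanh(query[:, :, None, :] + value[:, None, :, :]), axis=-1)
  distribution = tf.nn.softmax(scores)
  expected = tf.matmul(distribution, value)  # Matches `output`.
  ```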
  """

  def __init__(self, use_scale=True, **kwargs):
    super(AdditiveAttention, self).__init__(**kwargs)
    self.use_scale = use_scale

  def build(self, input_shape):
    v_shape = tensor_shape.TensorShape(input_shape[1])
    dim = v_shape[-1]
    if isinstance(dim, tensor_shape.Dimension):
      dim = dim.value
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',
          shape=[dim],
          initializer=init_ops.glorot_uniform_initializer(),
          dtype=self.dtype,
          trainable=True)
    else:
      self.scale = None
    super(AdditiveAttention, self).build(input_shape)

  def _calculate_scores(self, query, key):
    """Calculates attention scores as a nonlinear sum of query and key.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.
    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    # Reshape tensors to enable broadcasting.
    # Reshape into [batch_size, Tq, 1, dim].
    q_reshaped = array_ops.expand_dims(query, axis=-2)
    # Reshape into [batch_size, 1, Tv, dim].
    k_reshaped = array_ops.expand_dims(key, axis=-3)
    if self.use_scale:
      scale = self.scale
    else:
      scale = 1.
    return math_ops.reduce_sum(
        scale * math_ops.tanh(q_reshaped + k_reshaped), axis=-1)

  def get_config(self):
    config = {'use_scale': self.use_scale}
    base_config = super(AdditiveAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _lower_triangular_mask(shape):
  """Creates a lower-triangular boolean mask over the last 2 dimensions."""
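  # For example, for shape [1, 3, 3] this evaluates to (an illustrative
  # sketch, shown as a boolean array):
  #   [[[ True, False, False],
  #     [ True,  True, False],
  #     [ True,  True,  True]]]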
  row_index = math_ops.cumsum(
      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-2)
  col_index = math_ops.cumsum(
      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-1)
  return math_ops.greater_equal(row_index, col_index)


def _merge_masks(x, y):
  if x is None:
    return y
  if y is None:
    return x
  return math_ops.logical_and(x, y)
