# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Attention layers that can be used in sequence DNN/CNN models.

This file follows the terminology of https://arxiv.org/abs/1706.03762 Figure 2.
Attention is formed by three tensors: Query, Key and Value.
"""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import keras_export


class BaseDenseAttention(Layer):
  """Base Attention class for Dense networks.

  This class is suitable for Dense or CNN networks, and not for RNN networks.

  Implementations of attention mechanisms should inherit from this class, and
  reuse the `_apply_scores()` method.

  Args:
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
    `[batch_size, Tq, Tv]`.
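
  A subclass typically only needs to override `_calculate_scores()` (and
  `build()` if it has weights); masking, softmax, dropout and the weighted sum
  over `value` are handled by this base class. As an illustrative sketch (the
  `MyDotProductAttention` name below is hypothetical, not part of the API),
  assuming `import tensorflow as tf`:

  ```python
  class MyDotProductAttention(BaseDenseAttention):

    def _calculate_scores(self, query, key):
      # [batch_size, Tq, dim] x [batch_size, Tv, dim]^T -> [batch_size, Tq, Tv]
      return tf.matmul(query, key, transpose_b=True)
  ```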
  """

  def __init__(self, causal=False, dropout=0.0,
               **kwargs):
    super(BaseDenseAttention, self).__init__(**kwargs)
    self.causal = causal
    self.dropout = dropout
    self.supports_masking = True

  def _calculate_scores(self, query, key):
    """Calculates attention scores.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    raise NotImplementedError

  def _apply_scores(self, scores, value, scores_mask=None, training=None):
    """Applies attention scores to the given value tensor.

    To use this method in your attention layer, follow the steps:

    * Use `query` tensor of shape `[batch_size, Tq, dim]` and `key` tensor of
      shape `[batch_size, Tv, dim]` to calculate the attention `scores`.
    * Pass `scores` and `value` tensors to this method. The method applies
      `scores_mask`, calculates `attention_distribution = softmax(scores)`,
      then returns `matmul(attention_distribution, value)`.
    * Apply `query_mask` and return the result.

    Args:
      scores: Scores float tensor of shape `[batch_size, Tq, Tv]`.
      value: Value tensor of shape `[batch_size, Tv, dim]`.
      scores_mask: A boolean mask `Tensor` of shape `[batch_size, 1, Tv]` or
        `[batch_size, Tq, Tv]`. If given, scores at positions where
        `scores_mask==False` do not contribute to the result. It must contain
        at least one `True` value in each line along the last dimension.
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (no dropout).

    Returns:
      Tensor of shape `[batch_size, Tq, dim]`.
      Attention scores after masking and softmax with shape
      `[batch_size, Tq, Tv]`.
    """
    if scores_mask is not None:
      padding_mask = math_ops.logical_not(scores_mask)
      # Bias so padding positions do not contribute to attention distribution.
      # Note 65504. is the max float16 value.
      if scores.dtype is dtypes.float16:
        scores -= 65504. * math_ops.cast(padding_mask, dtype=scores.dtype)
      else:
        scores -= 1.e9 * math_ops.cast(padding_mask, dtype=scores.dtype)
    if training is None:
      training = backend.learning_phase()
    weights = nn.softmax(scores)

    def dropped_weights():
      return nn.dropout(weights, rate=self.dropout)

    weights = control_flow_util.smart_cond(training, dropped_weights,
                                           lambda: array_ops.identity(weights))
    return math_ops.matmul(weights, value), weights

  # TODO(b/125916026): Consider exposing a __call__ method with named args.
  def call(self,
           inputs,
           mask=None,
           training=None,
           return_attention_scores=False):
    self._validate_call_args(inputs=inputs, mask=mask)
    q = inputs[0]
    v = inputs[1]
    k = inputs[2] if len(inputs) > 2 else v
    q_mask = mask[0] if mask else None
    v_mask = mask[1] if mask else None
    scores = self._calculate_scores(query=q, key=k)
    if v_mask is not None:
      # Mask of shape [batch_size, 1, Tv].
      v_mask = array_ops.expand_dims(v_mask, axis=-2)
    if self.causal:
      # Creates a lower triangular mask, so position i cannot attend to
      # positions j>i. This prevents the flow of information from the future
      # into the past.
      scores_shape = array_ops.shape(scores)
      # causal_mask_shape = [1, Tq, Tv].
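      # E.g. for Tq = Tv = 3 the causal mask produced below is
      #   [[ True, False, False],
      #    [ True,  True, False],
      #    [ True,  True,  True]]
      # so position i can only attend to positions j <= i.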
      causal_mask_shape = array_ops.concat(
          [array_ops.ones_like(scores_shape[:-2]), scores_shape[-2:]],
          axis=0)
      causal_mask = _lower_triangular_mask(causal_mask_shape)
    else:
      causal_mask = None
    scores_mask = _merge_masks(v_mask, causal_mask)
    result, attention_scores = self._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask, training=training)
    if q_mask is not None:
      # Mask of shape [batch_size, Tq, 1].
      q_mask = array_ops.expand_dims(q_mask, axis=-1)
      result *= math_ops.cast(q_mask, dtype=result.dtype)
    if return_attention_scores:
      return result, attention_scores
    return result

  def compute_mask(self, inputs, mask=None):
    self._validate_call_args(inputs=inputs, mask=mask)
    if mask:
      q_mask = mask[0]
      if q_mask is None:
        return None
      return ops.convert_to_tensor_v2_with_dispatch(q_mask)
    return None

  def _validate_call_args(self, inputs, mask):
    """Validates arguments of the call method."""
    class_name = self.__class__.__name__
    if not isinstance(inputs, list):
      raise ValueError(
          '{} layer must be called on a list of inputs, namely [query, value] '
          'or [query, value, key].'.format(class_name))
    if len(inputs) < 2 or len(inputs) > 3:
      raise ValueError(
          '{} layer accepts inputs list of length 2 or 3, '
          'namely [query, value] or [query, value, key]. '
          'Given length: {}'.format(class_name, len(inputs)))
    if mask:
      if not isinstance(mask, list):
        raise ValueError(
            '{} layer mask must be a list, '
            'namely [query_mask, value_mask].'.format(class_name))
      if len(mask) < 2 or len(mask) > len(inputs):
        raise ValueError(
            '{} layer mask must be a list of length 2, namely [query_mask, '
            'value_mask]. Given length: {}'.format(class_name, len(mask)))

  def get_config(self):
    config = {
        'causal': self.causal,
        'dropout': self.dropout,
    }
    base_config = super(BaseDenseAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.Attention')
class Attention(BaseDenseAttention):
  """Dot-product attention layer, a.k.a. Luong-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
     product: `scores = tf.matmul(query, key, transpose_b=True)`.
  2. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  3. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.

  Args:
    use_scale: If `True`, will create a scalar variable to scale the attention
      scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
    `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `Attention` in a CNN+Attention network:

  ```python
  # Variable-length int sequences.
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.Attention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
  """

  def __init__(self, use_scale=False, **kwargs):
    super(Attention, self).__init__(**kwargs)
    self.use_scale = use_scale

  def build(self, input_shape):
    """Creates scale variable if use_scale==True."""
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',
          shape=(),
          initializer=init_ops.ones_initializer(),
          dtype=self.dtype,
          trainable=True)
    else:
      self.scale = None
    super(Attention, self).build(input_shape)

  def _calculate_scores(self, query, key):
    """Calculates attention scores as a query-key dot product.
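
    In other words, ignoring the batch dimension, `scores[i, j]` is the dot
    product of `query[i]` and `key[j]`, multiplied by the learned `scale`
    when `use_scale=True`.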

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    scores = math_ops.matmul(query, key, transpose_b=True)
    if self.scale is not None:
      scores *= self.scale
    return scores

  def get_config(self):
    config = {'use_scale': self.use_scale}
    base_config = super(Attention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.AdditiveAttention')
class AdditiveAttention(BaseDenseAttention):
  """Additive attention layer, a.k.a. Bahdanau-style attention.

  Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
  shape `[batch_size, Tv, dim]` and `key` tensor of shape
  `[batch_size, Tv, dim]`. The calculation follows the steps:

  1. Reshape `query` and `key` into shapes `[batch_size, Tq, 1, dim]`
     and `[batch_size, 1, Tv, dim]` respectively.
  2. Calculate scores with shape `[batch_size, Tq, Tv]` as a non-linear
     sum: `scores = tf.reduce_sum(tf.tanh(query + key), axis=-1)`
  3. Use scores to calculate a distribution with shape
     `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
  4. Use `distribution` to create a linear combination of `value` with
     shape `[batch_size, Tq, dim]`:
     `return tf.matmul(distribution, value)`.

  Args:
    use_scale: If `True`, will create a variable to scale the attention scores.
    causal: Boolean. Set to `True` for decoder self-attention. Adds a mask such
      that position `i` cannot attend to positions `j > i`. This prevents the
      flow of information from the future towards the past.
    dropout: Float between 0 and 1. Fraction of the units to drop for the
      attention scores.

  Call Args:

    inputs: List of the following tensors:
      * query: Query `Tensor` of shape `[batch_size, Tq, dim]`.
      * value: Value `Tensor` of shape `[batch_size, Tv, dim]`.
      * key: Optional key `Tensor` of shape `[batch_size, Tv, dim]`. If not
        given, will use `value` for both `key` and `value`, which is the
        most common case.
    mask: List of the following tensors:
      * query_mask: A boolean mask `Tensor` of shape `[batch_size, Tq]`.
        If given, the output will be zero at the positions where
        `mask==False`.
      * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
        If given, will apply the mask such that values at positions where
        `mask==False` do not contribute to the result.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (no dropout).
    return_attention_scores: bool, if `True`, returns the attention scores
      (after masking and softmax) as an additional output argument.

  Output:

    Attention outputs of shape `[batch_size, Tq, dim]`.
    [Optional] Attention scores after masking and softmax with shape
    `[batch_size, Tq, Tv]`.

  The meaning of `query`, `value` and `key` depends on the application. In the
  case of text similarity, for example, `query` is the sequence embeddings of
  the first piece of text and `value` is the sequence embeddings of the second
  piece of text. `key` is usually the same tensor as `value`.

  Here is a code example for using `AdditiveAttention` in a CNN+Attention
  network:

  ```python
  # Variable-length int sequences.
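  # `shape=(None,)` leaves the sequence length unspecified, so the query and
  # value sequences (`Tq` and `Tv` below) may have different lengths.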
  query_input = tf.keras.Input(shape=(None,), dtype='int32')
  value_input = tf.keras.Input(shape=(None,), dtype='int32')

  # Embedding lookup.
  token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
  # Query embeddings of shape [batch_size, Tq, dimension].
  query_embeddings = token_embedding(query_input)
  # Value embeddings of shape [batch_size, Tv, dimension].
  value_embeddings = token_embedding(value_input)

  # CNN layer.
  cnn_layer = tf.keras.layers.Conv1D(
      filters=100,
      kernel_size=4,
      # Use 'same' padding so outputs have the same shape as inputs.
      padding='same')
  # Query encoding of shape [batch_size, Tq, filters].
  query_seq_encoding = cnn_layer(query_embeddings)
  # Value encoding of shape [batch_size, Tv, filters].
  value_seq_encoding = cnn_layer(value_embeddings)

  # Query-value attention of shape [batch_size, Tq, filters].
  query_value_attention_seq = tf.keras.layers.AdditiveAttention()(
      [query_seq_encoding, value_seq_encoding])

  # Reduce over the sequence axis to produce encodings of shape
  # [batch_size, filters].
  query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
      query_seq_encoding)
  query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
      query_value_attention_seq)

  # Concatenate query and document encodings to produce a DNN input layer.
  input_layer = tf.keras.layers.Concatenate()(
      [query_encoding, query_value_attention])

  # Add DNN layers, and create Model.
  # ...
  ```
  """

  def __init__(self, use_scale=True, **kwargs):
    super(AdditiveAttention, self).__init__(**kwargs)
    self.use_scale = use_scale

  def build(self, input_shape):
    v_shape = tensor_shape.TensorShape(input_shape[1])
    dim = v_shape[-1]
    if isinstance(dim, tensor_shape.Dimension):
      dim = dim.value
    if self.use_scale:
      self.scale = self.add_weight(
          name='scale',
          shape=[dim],
          initializer=init_ops.glorot_uniform_initializer(),
          dtype=self.dtype,
          trainable=True)
    else:
      self.scale = None
    super(AdditiveAttention, self).build(input_shape)

  def _calculate_scores(self, query, key):
    """Calculates attention scores as a nonlinear sum of query and key.

    Args:
      query: Query tensor of shape `[batch_size, Tq, dim]`.
      key: Key tensor of shape `[batch_size, Tv, dim]`.

    Returns:
      Tensor of shape `[batch_size, Tq, Tv]`.
    """
    # Reshape tensors to enable broadcasting.
    # Reshape into [batch_size, Tq, 1, dim].
    q_reshaped = array_ops.expand_dims(query, axis=-2)
    # Reshape into [batch_size, 1, Tv, dim].
    k_reshaped = array_ops.expand_dims(key, axis=-3)
    if self.use_scale:
      scale = self.scale
    else:
      scale = 1.
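    # Broadcasting [batch_size, Tq, 1, dim] + [batch_size, 1, Tv, dim] yields a
    # [batch_size, Tq, Tv, dim] tensor; tanh, optional scaling, and a sum over
    # the last axis reduce it to [batch_size, Tq, Tv] scores.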
    return math_ops.reduce_sum(
        scale * math_ops.tanh(q_reshaped + k_reshaped), axis=-1)

  def get_config(self):
    config = {'use_scale': self.use_scale}
    base_config = super(AdditiveAttention, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def _lower_triangular_mask(shape):
  """Creates a lower-triangular boolean mask over the last 2 dimensions."""
  row_index = math_ops.cumsum(
      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-2)
  col_index = math_ops.cumsum(
      array_ops.ones(shape=shape, dtype=dtypes.int32), axis=-1)
  return math_ops.greater_equal(row_index, col_index)


def _merge_masks(x, y):
  if x is None:
    return y
  if y is None:
    return x
  return math_ops.logical_and(x, y)