# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests dense attention layers."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.keras import combinations
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import dense_attention
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
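  """Tests for `BaseDenseAttention._apply_scores`.

  As exercised below, `_apply_scores` softmaxes `scores` over the last axis,
  gives (effectively) zero weight to positions where `scores_mask` is False,
  and returns both the weighted sum over `value` and the attention
  distribution itself.
  """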

  def test_one_dim_with_mask(self):
    # Scores tensor of shape [1, 1, 1]
    scores = np.array([[[1.1]]], dtype=np.float32)
    # Value tensor of shape [1, 1, 1]
    v = np.array([[[1.6]]], dtype=np.float32)
    # Scores mask tensor of shape [1, 1, 1]
    scores_mask = np.array([[[True]]], dtype=np.bool_)
    actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask)

    # Expected softmax_scores = [[[1]]]
    expected_scores = np.array([[[1.]]], dtype=np.float32)
    self.assertAllClose(expected_scores, actual_scores)
    # Expected tensor of shape [1, 1, 1].
    # expected000 = softmax_scores[0, 0] * 1.6 = 1.6
    expected = np.array([[[1.6]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_one_dim_no_mask(self):
    # Scores tensor of shape [1, 1, 1]
    scores = np.array([[[1.1]]], dtype=np.float32)
    # Value tensor of shape [1, 1, 1]
    v = np.array([[[1.6]]], dtype=np.float32)
    actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
        scores=scores, value=v)

    # Expected softmax_scores = [[[1]]]
    expected_scores = np.array([[[1.]]], dtype=np.float32)
    self.assertAllClose(expected_scores, actual_scores)
    # Expected tensor of shape [1, 1, 1].
    # expected000 = softmax_scores[0, 0] * 1.6 = 1.6
    expected = np.array([[[1.6]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_multi_dim_with_mask(self):
    # Scores tensor of shape [1, 1, 3]
    scores = np.array([[[1., 0., 1.]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    # Scores mask tensor of shape [1, 1, 3]
    scores_mask = np.array([[[True, True, False]]], dtype=np.bool_)
    actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask)

    # Expected softmax scores = softmax(scores) with zeros in positions where
    # scores_mask == False.
    # => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863
    #    softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137
    #    softmax_scores002 = 0
    expected_scores = np.array([[[0.73105857863, 0.26894142137, 0.]]],
                               dtype=np.float32)
    self.assertAllClose(expected_scores, actual_scores)
    # Expected tensor of shape [1, 1, 1].
    # expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8
    #             = 1.35795272077
    expected = np.array([[[1.35795272077]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_multi_dim_no_mask(self):
    # Scores tensor of shape [1, 1, 3]
    scores = np.array([[[1., 0., 1.]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
        scores=scores, value=v)

    # Expected softmax_scores = softmax(scores).
    # => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
    #                      = 0.42231879825
    #    softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
    #                      = 0.15536240349
    #    softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
    #                      = 0.42231879825
    expected_scores = np.array(
        [[[0.42231879825, 0.15536240349, 0.42231879825]]], dtype=np.float32)
    self.assertAllClose(expected_scores, actual_scores)
    # Expected tensor of shape [1, 1, 1].
    # expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
    #               - 0.42231879825 * 0.8
    #             = 0.44660872104
    expected = np.array([[[0.44660872104]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_one_dim_batch_size_two(self):
    # Scores tensor of shape [2, 1, 1]
    scores = np.array([[[1.1]], [[2.1]]], dtype=np.float32)
    # Value tensor of shape [2, 1, 1]
    v = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
    # Scores mask tensor of shape [2, 1, 1]
    scores_mask = np.array([[[True]], [[True]]], dtype=np.bool_)
    actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
        scores=scores, value=v, scores_mask=scores_mask)

    # Expected softmax_scores = [[[1]], [[1]]]
    expected_scores = np.array([[[1.]], [[1.]]], dtype=np.float32)
    self.assertAllClose(expected_scores, actual_scores)
    # Expected tensor of shape [2, 1, 1].
    # expected000 = softmax_scores[0, 0] * 1.6 = 1.6
    # expected100 = softmax_scores[1, 0] * 2.6 = 2.6
    expected = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_shape_with_dropout(self):
    # scores: Scores float tensor of shape `[batch_size, tq, tv]`.
    # value: Value tensor of shape `[batch_size, tv, dim]`.
    batch_size = 4
    tq = 5
    tv = 6
    dim = 7
    scores = np.ones((batch_size, tq, tv))
    value = np.ones((batch_size, tv, dim))
    actual, actual_scores = dense_attention.BaseDenseAttention(
        dropout=0.1)._apply_scores(
            scores=scores, value=value, training=False)

    # Expected Tensor of shape `[batch_size, tq, tv]`.
    expected_scores_shape = [batch_size, tq, tv]
    self.assertAllEqual(expected_scores_shape, array_ops.shape(actual_scores))
    # Expected Tensor of shape `[batch_size, tq, dim]`.
    expected_shape = [batch_size, tq, dim]
    self.assertAllEqual(expected_shape, array_ops.shape(actual))

  def test_serialization(self):
    # Test serialization with causal
    layer = dense_attention.BaseDenseAttention(causal=True)

    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(config)
    self.assertEqual(new_layer.causal, True)

    config = layer.get_config()
    new_layer = dense_attention.BaseDenseAttention.from_config(config)
    self.assertEqual(new_layer.causal, True)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class AttentionTest(test.TestCase, parameterized.TestCase):
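  """Tests for the `Attention` layer.

  `Attention` scores each query against each key with a batched dot product,
  optionally multiplied by a learned `scale`, as the expected values below
  spell out element by element.
  """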

  def test_calculate_scores_one_dim(self):
    # Query tensor of shape [1, 1, 1]
    q = np.array([[[1.1]]], dtype=np.float32)
    # Key tensor of shape [1, 1, 1]
    k = np.array([[[1.6]]], dtype=np.float32)
    attention_layer = dense_attention.Attention()
    attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1]))
    actual = attention_layer._calculate_scores(query=q, key=k)

    # Expected tensor of shape [1, 1, 1].
    # expected000 = 1.1*1.6 = 1.76
    expected = np.array([[[1.76]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_calculate_scores_multi_dim(self):
    # Query tensor of shape [1, 2, 4]
    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
    # Key tensor of shape [1, 3, 4]
    k = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    attention_layer = dense_attention.Attention()
    attention_layer.build(input_shape=([1, 2, 4], [1, 3, 4]))
    actual = attention_layer._calculate_scores(query=q, key=k)

    # Expected tensor of shape [1, 2, 3].
    # expected000 = 1.*1.5+1.1*1.6+1.2*1.7+1.3*1.8 = 7.64
    # expected001 = 1.*2.5+1.1*2.6+1.2*2.7+1.3*2.8 = 12.24
    # expected002 = 1.*3.5+1.1*3.6+1.2*3.7+1.3*3.8 = 16.84
    # expected010 = 2.*1.5+2.1*1.6+2.2*1.7+2.3*1.8 = 14.24
    # expected011 = 2.*2.5+2.1*2.6+2.2*2.7+2.3*2.8 = 22.84
    # expected012 = 2.*3.5+2.1*3.6+2.2*3.7+2.3*3.8 = 31.44
    expected = np.array([[[7.64, 12.24, 16.84], [14.24, 22.84, 31.44]]],
                        dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_calculate_scores_one_dim_batch_size_two(self):
    # Query tensor of shape [2, 1, 1]
    q = np.array([[[1.1]], [[2.1]]], dtype=np.float32)
    # Key tensor of shape [2, 1, 1]
    k = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
    attention_layer = dense_attention.Attention()
    attention_layer.build(input_shape=([2, 1, 1], [2, 1, 1]))
    actual = attention_layer._calculate_scores(query=q, key=k)

    # Expected tensor of shape [2, 1, 1].
    # expected000 = 1.1*1.6 = 1.76
    # expected100 = 2.1*2.6 = 5.46
    expected = np.array([[[1.76]], [[5.46]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_calculate_scores_one_dim_with_scale(self):
    """Tests that scores are multiplied by scale."""
    # Query tensor of shape [1, 1, 1]
    q = np.array([[[1.1]]], dtype=np.float32)
    # Key tensor of shape [1, 1, 1]
    k = np.array([[[1.6]]], dtype=np.float32)
    attention_layer = dense_attention.Attention(use_scale=True)
    attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1]))
    attention_layer.scale = -2.
    actual = attention_layer._calculate_scores(query=q, key=k)

    # Expected tensor of shape [1, 1, 1].
    # expected000 = -2*1.1*1.6 = -3.52
    expected = np.array([[[-3.52]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_shape(self):
    # Query tensor of shape [1, 2, 4]
    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 4]
    v = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.Attention()
    actual = attention_layer([q, v], mask=[None, v_mask])

    expected_shape = [1, 2, 4]
    self.assertAllEqual(expected_shape, array_ops.shape(actual))

  def test_shape_with_key(self):
    # Query tensor of shape [1, 2, 4]
    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 4]
    v = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    # Key tensor of shape [1, 3, 4]
    k = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.Attention()
    actual = attention_layer([q, v, k], mask=[None, v_mask])

    expected_shape = [1, 2, 4]
    self.assertAllEqual(expected_shape, array_ops.shape(actual))

  def test_multi_dim(self):
    # Query tensor of shape [1, 1, 1]
    q = np.array([[[1.1]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.Attention()
    actual = attention_layer([q, v], mask=[None, v_mask])

    # Expected scores of shape [1, 1, 3]
    # scores = [[[1.1*1.6, 1.1*0.7, -1.1*0.8]]] = [[[1.76, 0.77, -0.88]]]
    # Expected attention distribution = softmax(scores) with zeros in
    # positions where v_mask == False.
    # => attention_distribution000 = exp(1.76)/(exp(1.76) + exp(0.77))
    #                              = 0.72908792234
    #    attention_distribution001 = exp(0.77)/(exp(1.76) + exp(0.77))
    #                              = 0.27091207765
    #    attention_distribution002 = 0
    #
    # Expected tensor of shape [1, 1, 1].
    # expected000 = 0.72908792234 * 1.6 + 0.27091207765 * 0.7 - 0 * 0.8
    #             = 1.3561791301
    expected = np.array([[[1.3561791301]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_multi_dim_with_key(self):
    # Query tensor of shape [1, 1, 1]
    q = np.array([[[1.1]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[0.5], [0.8], [-0.3]]], dtype=np.float32)
    # Key tensor of shape [1, 3, 1]
    k = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.Attention()
    actual = attention_layer([q, v, k], mask=[None, v_mask])

    # Expected scores of shape [1, 1, 3]
    # scores = [[[1.1*1.6, 1.1*0.7, -1.1*0.8]]] = [[[1.76, 0.77, -0.88]]]
    # Expected attention distribution = softmax(scores) with zeros in
    # positions where v_mask == False.
    # => attention_distribution000 = exp(1.76)/(exp(1.76) + exp(0.77))
    #                              = 0.72908792234
    #    attention_distribution001 = exp(0.77)/(exp(1.76) + exp(0.77))
    #                              = 0.27091207765
    #    attention_distribution002 = 0
    #
    # Expected tensor of shape [1, 1, 1].
    # expected000 = 0.72908792234 * 0.5 + 0.27091207765 * 0.8 - 0 * 0.3
    #             = 0.58127362329
    expected = np.array([[[0.58127362329]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  @parameterized.named_parameters(
      ('', False),
      ('return_attention_scores', True),
  )
  def test_multi_dim_with_query_mask(self, return_attention_scores):
    # Query tensor of shape [1, 2, 1]
    q = np.array([[[1.1], [-0.5]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    # Query mask tensor of shape [1, 2]
    q_mask = np.array([[True, False]], dtype=np.bool_)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.Attention()
    if return_attention_scores:
      actual, actual_scores = attention_layer(
          [q, v],
          mask=[q_mask, v_mask],
          return_attention_scores=return_attention_scores)
    else:
      actual = attention_layer([q, v],
                               mask=[q_mask, v_mask],
                               return_attention_scores=return_attention_scores)

    # Expected scores of shape [1, 2, 3]
    # scores = [[[1.1*1.6, 1.1*0.7, -1.1*0.8], [-0.5*1.6, -0.5*0.7, 0.5*0.8]]]
    #        = [[[1.76, 0.77, -0.88], [-0.8, -0.35, 0.4]]]
    # Expected attention distribution = softmax(scores) with zeros in
    # positions where v_mask == False.
    # => attention_distribution000 = exp(1.76)/(exp(1.76) + exp(0.77))
    #                              = 0.72908792234
    #    attention_distribution001 = exp(0.77)/(exp(1.76) + exp(0.77))
    #                              = 0.27091207765
    #    attention_distribution002 = 0
    # => attention_distribution010 = exp(-0.8)/(exp(-0.8) + exp(-0.35))
    #                              = 0.38936076605
    #    attention_distribution011 = exp(-0.35)/(exp(-0.8) + exp(-0.35))
    #                              = 0.61063923394
    #    attention_distribution012 = 0
    if return_attention_scores:
      expected_scores = np.array([[[0.72908792234, 0.27091207765, 0.],
                                   [0.38936076605, 0.61063923394, 0.]]],
                                 dtype=np.float32)
      self.assertAllClose(expected_scores, actual_scores)
    # Expected tensor of shape [1, 2, 1] with zeros where q_mask == False.
    # expected000 = 0.72908792234 * 1.6 + 0.27091207765 * 0.7 - 0 * 0.8
    #             = 1.3561791301
    # expected010 = 0
    expected = np.array([[[1.3561791301], [0.]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_scale_None(self):
    """Tests that scale is None by default."""
    attention_layer = dense_attention.Attention()
    attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1]))
    self.assertIsNone(attention_layer.scale)

  def test_scale_init_eager(self):
    """Tests that scale initializes to 1 when use_scale=True."""
    if not context.executing_eagerly():
      self.skipTest('Only run in eager mode')
    attention_layer = dense_attention.Attention(use_scale=True)
    attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1]))
    self.assertAllClose(1., attention_layer.scale.value())

  def test_scale_init_graph(self):
    """Tests that scale initializes to 1 when use_scale=True."""
    with self.cached_session() as sess:
      attention_layer = dense_attention.Attention(use_scale=True)
      attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1]))
      sess.run(attention_layer.scale.initializer)
      self.assertAllClose(1., attention_layer.scale.value())

  @parameterized.named_parameters(
      ('', False),
      ('return_attention_scores', True),
  )
  def test_self_attention_causal(self, return_attention_scores):
    # Query-value tensor of shape [1, 3, 1]
    q = np.array([[[0.5], [0.8], [-0.3]]], dtype=np.float32)
    attention_layer = dense_attention.Attention(causal=True)
    if return_attention_scores:
      actual, actual_scores = attention_layer(
          [q, q], return_attention_scores=return_attention_scores)
    else:
      actual = attention_layer([q, q],
                               return_attention_scores=return_attention_scores)

    # Expected scores of shape [1, 3, 3]
    # scores = [[0.25, 0.4, -0.15], [0.4, 0.64, -0.24], [-0.15, -0.24, 0.09]]
    # Expected attention distribution = softmax(scores) lower triangular
    # => attention_distribution00 = [1., 0., 0.]
    #    attention_distribution01
    #      = [exp(0.4), exp(0.64), 0.] / (exp(0.4) + exp(0.64))
    #      = [0.44028635073, 0.55971364926, 0.]
    #    attention_distribution02
    #      = [exp(-0.15), exp(-0.24), exp(0.09)]
    #        / (exp(-0.15) + exp(-0.24) + exp(0.09))
    #      = [0.31395396638, 0.28693232061, 0.399113713]
    if return_attention_scores:
      expected_scores = np.array(
          [[[1., 0., 0.], [0.44028635073, 0.55971364926, 0.],
            [0.31395396638, 0.28693232061, 0.399113713]]],
          dtype=np.float32)
      self.assertAllClose(expected_scores, actual_scores)
    # Expected tensor of shape [1, 3, 1].
    # expected000 = 0.5
    # expected010 = 0.44028635073 * 0.5 + 0.55971364926 * 0.8
    #             = 0.66791409477
    # expected020 = 0.31395396638 * 0.5 +0.28693232061 * 0.8 -0.399113713 * 0.3
    #             = 0.26678872577
    expected = np.array([[[0.5], [0.66791409477], [0.26678872577]]],
                        dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_inputs_not_list(self):
    attention_layer = dense_attention.Attention()
    q = np.array([[[1.1]]], dtype=np.float32)
    with self.assertRaisesRegex(
        ValueError, 'Attention layer must be called on a list of inputs'):
      attention_layer(q)

  def test_inputs_too_short(self):
    attention_layer = dense_attention.Attention()
    q = np.array([[[1.1]]], dtype=np.float32)
    with self.assertRaisesRegex(
        ValueError, 'Attention layer accepts inputs list of length 2 or 3'):
      attention_layer([q])

  def test_inputs_too_long(self):
    attention_layer = dense_attention.Attention()
    q = np.array([[[1.1]]], dtype=np.float32)
    with self.assertRaisesRegex(
        ValueError, 'Attention layer accepts inputs list of length 2 or 3'):
      attention_layer([q, q, q, q])

  def test_mask_not_list(self):
    attention_layer = dense_attention.Attention()
    q = np.array([[[1.1]]], dtype=np.float32)
    mask = np.array([[True]], dtype=np.bool_)
    with self.assertRaisesRegex(ValueError,
                                'Attention layer mask must be a list'):
      attention_layer([q, q], mask=mask)

  def test_mask_too_short(self):
    attention_layer = dense_attention.Attention()
    q = np.array([[[1.1]]], dtype=np.float32)
    mask = np.array([[True]], dtype=np.bool_)
    with self.assertRaisesRegex(
        ValueError, 'Attention layer mask must be a list of length 2'):
      attention_layer([q, q], mask=[mask])

  def test_mask_too_long(self):
    attention_layer = dense_attention.Attention()
    q = np.array([[[1.1]]], dtype=np.float32)
    mask = np.array([[True]], dtype=np.bool_)
    with self.assertRaisesRegex(
        ValueError, 'Attention layer mask must be a list of length 2'):
      attention_layer([q, q], mask=[mask, mask, mask])

  def test_override_mask(self):
    attention_layer = dense_attention.Attention()
    q = core.Masking()(np.array([[[1.1]]], dtype=np.float32))
    mask = np.array([[False]], dtype=np.bool_)
    actual = attention_layer([q, q], mask=[mask, mask])
    self.assertAllClose([[[0]]], actual)

  def test_implicit_mask(self):
    attention_layer = dense_attention.Attention()
    q = core.Masking(1.1)(np.array([[[1.1], [1]]], dtype=np.float32))
    v = core.Masking(1.2)(np.array([[[1.2], [1]]], dtype=np.float32))
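    # Masking marks timesteps equal to its mask_value as masked, so the first
    # timestep of both q and v carries an implicit mask here; masked query
    # positions produce an all-zero output row.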
    actual = attention_layer([q, v])
    self.assertAllClose([[[0], [1]]], actual)

  @parameterized.named_parameters(
      ('', False),
      ('use_scale', True),
  )
  def test_serialization(self, use_scale):
    # Test serialization with use_scale
    layer = dense_attention.Attention(use_scale=use_scale)

    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(config)
    self.assertEqual(new_layer.use_scale, use_scale)

    config = layer.get_config()
    new_layer = dense_attention.Attention.from_config(config)
    self.assertEqual(new_layer.use_scale, use_scale)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
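  """Tests for the `AdditiveAttention` layer.

  `AdditiveAttention` (Bahdanau-style) scores each query/key pair by summing
  scale * tanh(query + key) over the feature axis, as the expected values
  below spell out element by element.
  """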

  def test_calculate_scores_one_dim(self):
    # Query tensor of shape [1, 1, 1]
    q = np.array([[[1.1]]], dtype=np.float32)
    # Key tensor of shape [1, 1, 1]
    k = np.array([[[1.6]]], dtype=np.float32)
    attention_layer = dense_attention.AdditiveAttention()
    attention_layer.build(input_shape=([1, 1, 1], [1, 1, 1]))
    # Scale tensor of shape [1]
    attention_layer.scale = np.array([[[0.5]]], dtype=np.float32)
    actual = attention_layer._calculate_scores(query=q, key=k)

    # Expected tensor of shape [1, 1, 1].
    # expected000 = 0.5 * tanh(1.1 + 1.6) = 0.49550372683
    expected = np.array([[[0.49550372683]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_calculate_scores_multi_dim(self):
    # Query tensor of shape [1, 2, 4]
    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
    # Key tensor of shape [1, 3, 4]
    k = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    attention_layer = dense_attention.AdditiveAttention()
    attention_layer.build(input_shape=([1, 2, 4], [1, 3, 4]))
    # Scale tensor of shape [4]
    attention_layer.scale = np.array([[[0.5, 0.6, 0.7, 0.8]]], dtype=np.float32)
    actual = attention_layer._calculate_scores(query=q, key=k)

    # pylint:disable=line-too-long
    # expected000 = 0.5*tanh(1.+1.5) + 0.6*tanh(1.1+1.6) + 0.7*tanh(1.2+1.7) + 0.8*tanh(1.3+1.8) = 2.58044532581
    # expected001 = 0.5*tanh(1.+2.5) + 0.6*tanh(1.1+2.6) + 0.7*tanh(1.2+2.7) + 0.8*tanh(1.3+2.8) = 2.59734317449
    # expected002 = 0.5*tanh(1.+3.5) + 0.6*tanh(1.1+3.6) + 0.7*tanh(1.2+3.7) + 0.8*tanh(1.3+3.8) = 2.59964024652
    # expected010 = 0.5*tanh(2.+1.5) + 0.6*tanh(2.1+1.6) + 0.7*tanh(2.2+1.7) + 0.8*tanh(2.3+1.8) = 2.59734317449
    # expected011 = 0.5*tanh(2.+2.5) + 0.6*tanh(2.1+2.6) + 0.7*tanh(2.2+2.7) + 0.8*tanh(2.3+2.8) = 2.59964024652
    # expected012 = 0.5*tanh(2.+3.5) + 0.6*tanh(2.1+3.6) + 0.7*tanh(2.2+3.7) + 0.8*tanh(2.3+3.8) = 2.59995130916
    # pylint:enable=line-too-long
    expected = np.array([[[2.58044532581, 2.59734317449, 2.59964024652],
                          [2.59734317449, 2.59964024652, 2.59995130916]]],
                        dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_calculate_scores_one_dim_batch_size_two(self):
    # Query tensor of shape [2, 1, 1]
    q = np.array([[[1.1]], [[2.1]]], dtype=np.float32)
    # Key tensor of shape [2, 1, 1]
    k = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
    attention_layer = dense_attention.AdditiveAttention()
    attention_layer.build(input_shape=([2, 1, 1], [2, 1, 1]))
    # Scale tensor of shape [1]
    attention_layer.scale = np.array([[[0.5]]], dtype=np.float32)
    actual = attention_layer._calculate_scores(query=q, key=k)

    # Expected tensor of shape [2, 1, 1].
    # expected000 = 0.5 * tanh(1.1 + 1.6) = 0.49550372683
    # expected100 = 0.5 * tanh(2.1 + 2.6) = 0.49991728277
    expected = np.array([[[0.49550372683]], [[0.49991728277]]],
                        dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_shape(self):
    # Query tensor of shape [1, 2, 4]
    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 4]
    v = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.AdditiveAttention()
    actual = attention_layer([q, v], mask=[None, v_mask])

    expected_shape = [1, 2, 4]
    self.assertAllEqual(expected_shape, array_ops.shape(actual))

  def test_shape_no_scale(self):
    # Query tensor of shape [1, 2, 4]
    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 4]
    v = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.AdditiveAttention(use_scale=False)
    actual = attention_layer([q, v], mask=[None, v_mask])

    expected_shape = [1, 2, 4]
    self.assertAllEqual(expected_shape, array_ops.shape(actual))

  def test_shape_with_key(self):
    # Query tensor of shape [1, 2, 4]
    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 4]
    v = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    # Key tensor of shape [1, 3, 4]
    k = np.array(
        [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
        dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.AdditiveAttention()
    actual = attention_layer([q, v, k], mask=[None, v_mask])

    expected_shape = [1, 2, 4]
    self.assertAllEqual(expected_shape, array_ops.shape(actual))

  def test_multi_dim(self):
    # Query tensor of shape [1, 1, 1]
    q = np.array([[[1.1]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.AdditiveAttention()
    attention_layer.build(input_shape=([1, 1, 1], [1, 3, 1]))
    # Scale tensor of shape [1]
    attention_layer.scale = np.array([[[0.5]]], dtype=np.float32)
    actual = attention_layer([q, v], mask=[None, v_mask])

    # pylint:disable=line-too-long
    # Expected scores of shape [1, 1, 3]
    # scores = [[[0.5 * tanh(1.1 + 1.6), 0.5 * tanh(1.1 + 0.7), 0.5 * tanh(1.1 - 0.8)]]]
    #        = [[[0.49550372683, 0.47340300642, 0.14565630622]]]
    # Expected attention distribution = softmax(scores) with zeros in
    # positions where v_mask == False.
    # => attention_distribution000
    #      = exp(0.49550372683)/(exp(0.49550372683) + exp(0.47340300642))
    #      = 0.50552495521
    #    attention_distribution001
    #      = exp(0.47340300642)/(exp(0.49550372683) + exp(0.47340300642))
    #      = 0.49447504478
    #    attention_distribution002 = 0
    #
    # Expected tensor of shape [1, 1, 1].
    # expected000 = 0.50552495521 * 1.6 + 0.49447504478 * 0.7 - 0 * 0.8
    #             = 1.15497245968
    # pylint:enable=line-too-long
    expected = np.array([[[1.15497245968]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_multi_dim_with_key(self):
    # Query tensor of shape [1, 1, 1]
    q = np.array([[[1.1]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[0.5], [0.8], [-0.3]]], dtype=np.float32)
    # Key tensor of shape [1, 3, 1]
    k = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.AdditiveAttention()
    attention_layer.build(input_shape=([1, 1, 1], [1, 3, 1]))
    # Scale tensor of shape [1]
    attention_layer.scale = np.array([[[0.5]]], dtype=np.float32)
    actual = attention_layer([q, v, k], mask=[None, v_mask])

    # pylint:disable=line-too-long
    # Expected scores of shape [1, 1, 3]
    # scores = [[[0.5 * tanh(1.1 + 1.6), 0.5 * tanh(1.1 + 0.7), 0.5 * tanh(1.1 - 0.8)]]]
    #        = [[[0.49550372683, 0.47340300642, 0.14565630622]]]
    # Expected attention distribution = softmax(scores) with zeros in
    # positions where v_mask == False.
    # => attention_distribution000
    #        = exp(0.49550372683)/(exp(0.49550372683) + exp(0.47340300642))
    #        = 0.50552495521
    #    attention_distribution001
    #        = exp(0.47340300642)/(exp(0.49550372683) + exp(0.47340300642))
    #        = 0.49447504478
    #    attention_distribution002 = 0
    #
    # Expected tensor of shape [1, 1, 1].
    # expected000 = 0.50552495521 * 0.5 + 0.49447504478 * 0.8 - 0 * 0.3
    #             = 0.64834251342
    # pylint:enable=line-too-long
    expected = np.array([[[0.64834251342]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_multi_dim_with_query_mask(self):
    # Query tensor of shape [1, 2, 1]
    q = np.array([[[1.1], [-0.5]]], dtype=np.float32)
    # Value tensor of shape [1, 3, 1]
    v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
    # Query mask tensor of shape [1, 2]
    q_mask = np.array([[True, False]], dtype=np.bool_)
    # Value mask tensor of shape [1, 3]
    v_mask = np.array([[True, True, False]], dtype=np.bool_)
    attention_layer = dense_attention.AdditiveAttention()
    attention_layer.build(input_shape=([1, 1, 1], [1, 3, 1]))
    # Scale tensor of shape [1]
    attention_layer.scale = np.array([[[0.5]]], dtype=np.float32)
    actual = attention_layer([q, v], mask=[q_mask, v_mask])

    # pylint:disable=line-too-long
    # Expected scores of shape [1, 2, 3]
    # scores = [[[0.5 * tanh(1.1 + 1.6), 0.5 * tanh(1.1 + 0.7), 0.5 * tanh(1.1 - 0.8)],
    #            [0.5 * tanh(-0.5 + 1.6), 0.5 * tanh(-0.5 + 0.7), 0.5 * tanh(-0.5 - 0.8)]]]
    #        = [[[0.49550372683, 0.47340300642, 0.14565630622],
    #            [0.40024951088, 0.09868766011, -0.43086157965]]]
    # Expected attention distribution = softmax(scores) with zeros in
    # positions where v_mask == False.
    # => attention_distribution000
    #        = exp(0.49550372683)/(exp(0.49550372683) + exp(0.47340300642))
    #        = 0.50552495521
    #    attention_distribution001
    #        = exp(0.47340300642)/(exp(0.49550372683) + exp(0.47340300642))
    #        = 0.49447504478
    #    attention_distribution002 = 0
    # => attention_distribution010
    #        = exp(0.40024951088)/(exp(0.40024951088) + exp(0.09868766011))
    #        = 0.57482427975
    #    attention_distribution011
    #        = exp(0.09868766011)/(exp(0.40024951088) + exp(0.09868766011))
    #        = 0.42517572025
    #    attention_distribution012 = 0
    #
    # Expected tensor of shape [1, 2, 1] with zeros where q_mask == False.
    # expected000 = 0.50552495521 * 1.6 + 0.49447504478 * 0.7 - 0 * 0.8
    #             = 1.15497245968
    # expected010 = 0
    # pylint:enable=line-too-long
    expected = np.array([[[1.15497245968], [0.]]], dtype=np.float32)
    self.assertAllClose(expected, actual)

  def test_serialization(self):
    # Test serialization with use_scale
    layer = dense_attention.AdditiveAttention(use_scale=True)

    config = keras.layers.serialize(layer)
    new_layer = keras.layers.deserialize(config)
    self.assertEqual(new_layer.use_scale, True)

    config = layer.get_config()
    new_layer = dense_attention.AdditiveAttention.from_config(config)
    self.assertEqual(new_layer.use_scale, True)

  @testing_utils.enable_v2_dtype_behavior
  def test_mixed_float16_policy(self):
    # Test case for GitHub issue:
    # https://github.com/tensorflow/tensorflow/issues/46064
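    # Only successful execution under the mixed_float16 policy is checked.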
    with policy.policy_scope('mixed_float16'):
      q = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=1), 'float16')
      v = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=2), 'float16')
      k = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=3), 'float16')
      layer = dense_attention.AdditiveAttention(causal=True)
      _ = layer([q, v, k])


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class LowerTriangularMaskTest(test.TestCase, parameterized.TestCase):
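  """Tests for `_lower_triangular_mask`.

  `_lower_triangular_mask` builds the boolean lower-triangular mask used for
  causal attention, so position i can only attend to positions <= i.
  """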

  def test_square_shape(self):
    actual = dense_attention._lower_triangular_mask([3, 3])
    expected = np.array(
        [[True, False, False], [True, True, False], [True, True, True]],
        dtype=np.bool_)
    self.assertAllEqual(expected, actual)

  def test_orthogonal_shape(self):
    actual = dense_attention._lower_triangular_mask([3, 2])
    expected = np.array([[True, False], [True, True], [True, True]],
                        dtype=np.bool_)
    self.assertAllEqual(expected, actual)

  def test_three_dim(self):
    actual = dense_attention._lower_triangular_mask([1, 3, 3])
    expected = np.array(
        [[[True, False, False], [True, True, False], [True, True, True]]],
        dtype=np.bool_)
    self.assertAllEqual(expected, actual)


if __name__ == '__main__':
  test.main()