• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for merge layers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python import keras
25from tensorflow.python.framework import test_util as tf_test_util
26from tensorflow.python.keras import backend as K
27from tensorflow.python.keras import keras_parameterized
28from tensorflow.python.keras import testing_utils
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.ops.ragged import ragged_factory_ops
31from tensorflow.python.platform import test
32
33
@keras_parameterized.run_all_keras_modes
class MergeLayersTest(keras_parameterized.TestCase):
  """End-to-end tests for the Keras merge layers.

  Most tests build a functional model around a merge layer, configure it for
  the mode under test via `testing_utils.should_run_eagerly()` and
  `testing_utils.should_run_tf_function()`, and compare `model.predict`
  against a NumPy reference computation.
  """

  def test_merge_add(self):
    """Checks `Add` output values/shape and its `compute_mask` validation."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    i3 = keras.layers.Input(shape=(4, 5))

    add_layer = keras.layers.Add()
    o = add_layer([i1, i2, i3])
    self.assertListEqual(o.shape.as_list(), [None, 4, 5])
    model = keras.models.Model([i1, i2, i3], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()

    x1 = np.random.random((2, 4, 5))
    x2 = np.random.random((2, 4, 5))
    x3 = np.random.random((2, 4, 5))
    out = model.predict([x1, x2, x3])
    self.assertEqual(out.shape, (2, 4, 5))
    self.assertAllClose(out, x1 + x2 + x3, atol=1e-4)

    # No input masks -> no output mask; when masks are supplied, the merged
    # mask is asserted to be truthy everywhere.
    self.assertEqual(
        add_layer.compute_mask([i1, i2, i3], [None, None, None]), None)
    self.assertTrue(
        np.all(
            K.eval(
                add_layer.compute_mask(
                    [i1, i2], [K.variable(x1), K.variable(x2)]))))

    with self.assertRaisesRegexp(ValueError, '`mask` should be a list.'):
      add_layer.compute_mask([i1, i2, i3], x1)
    with self.assertRaisesRegexp(ValueError, '`inputs` should be a list.'):
      add_layer.compute_mask(i1, [None, None, None])
    with self.assertRaisesRegexp(ValueError, ' should have the same length.'):
      add_layer.compute_mask([i1, i2, i3], [None, None])

  def test_merge_subtract(self):
    """Checks `Subtract` values/shape, masks, and its two-input requirement."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    # `i3` is only used to build an invalid three-input call below.
    i3 = keras.layers.Input(shape=(4, 5))

    subtract_layer = keras.layers.Subtract()
    o = subtract_layer([i1, i2])
    self.assertListEqual(o.shape.as_list(), [None, 4, 5])
    model = keras.models.Model([i1, i2], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()

    x1 = np.random.random((2, 4, 5))
    x2 = np.random.random((2, 4, 5))
    out = model.predict([x1, x2])
    self.assertEqual(out.shape, (2, 4, 5))
    self.assertAllClose(out, x1 - x2, atol=1e-4)

    self.assertEqual(subtract_layer.compute_mask([i1, i2], [None, None]), None)
    self.assertTrue(
        np.all(
            K.eval(
                subtract_layer.compute_mask(
                    [i1, i2], [K.variable(x1), K.variable(x2)]))))

    with self.assertRaisesRegexp(ValueError, '`mask` should be a list.'):
      subtract_layer.compute_mask([i1, i2], x1)
    with self.assertRaisesRegexp(ValueError, '`inputs` should be a list.'):
      subtract_layer.compute_mask(i1, [None, None])
    with self.assertRaisesRegexp(ValueError,
                                 'layer should be called on exactly 2 inputs'):
      subtract_layer([i1, i2, i3])
    with self.assertRaisesRegexp(ValueError,
                                 'layer should be called on exactly 2 inputs'):
      subtract_layer([i1])

  def test_merge_multiply(self):
    """Checks the `multiply` helper on three inputs against NumPy."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    i3 = keras.layers.Input(shape=(4, 5))
    o = keras.layers.multiply([i1, i2, i3])
    self.assertListEqual(o.shape.as_list(), [None, 4, 5])
    model = keras.models.Model([i1, i2, i3], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()

    x1 = np.random.random((2, 4, 5))
    x2 = np.random.random((2, 4, 5))
    x3 = np.random.random((2, 4, 5))
    out = model.predict([x1, x2, x3])
    self.assertEqual(out.shape, (2, 4, 5))
    self.assertAllClose(out, x1 * x2 * x3, atol=1e-4)

  def test_merge_average(self):
    """Checks the `average` helper on two inputs against NumPy."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    o = keras.layers.average([i1, i2])
    self.assertListEqual(o.shape.as_list(), [None, 4, 5])
    model = keras.models.Model([i1, i2], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()

    x1 = np.random.random((2, 4, 5))
    x2 = np.random.random((2, 4, 5))
    out = model.predict([x1, x2])
    self.assertEqual(out.shape, (2, 4, 5))
    self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4)

  def test_merge_maximum(self):
    """Checks the `maximum` helper against `np.maximum`."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    o = keras.layers.maximum([i1, i2])
    self.assertListEqual(o.shape.as_list(), [None, 4, 5])
    model = keras.models.Model([i1, i2], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()

    x1 = np.random.random((2, 4, 5))
    x2 = np.random.random((2, 4, 5))
    out = model.predict([x1, x2])
    self.assertEqual(out.shape, (2, 4, 5))
    self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4)

  def test_merge_minimum(self):
    """Checks the `minimum` helper against `np.minimum`."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    o = keras.layers.minimum([i1, i2])
    self.assertListEqual(o.shape.as_list(), [None, 4, 5])
    model = keras.models.Model([i1, i2], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()

    x1 = np.random.random((2, 4, 5))
    x2 = np.random.random((2, 4, 5))
    out = model.predict([x1, x2])
    self.assertEqual(out.shape, (2, 4, 5))
    self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4)

  def test_merge_concatenate(self):
    """Checks `Concatenate(axis=1)` values/shape and its input validation."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    concat_layer = keras.layers.Concatenate(axis=1)
    o = concat_layer([i1, i2])
    self.assertListEqual(o.shape.as_list(), [None, 8, 5])
    model = keras.models.Model([i1, i2], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()

    x1 = np.random.random((2, 4, 5))
    x2 = np.random.random((2, 4, 5))
    out = model.predict([x1, x2])
    self.assertEqual(out.shape, (2, 8, 5))
    self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)

    self.assertEqual(concat_layer.compute_mask([i1, i2], [None, None]), None)
    self.assertTrue(
        np.all(
            K.eval(
                concat_layer.compute_mask(
                    [i1, i2], [K.variable(x1), K.variable(x2)]))))

    with self.assertRaisesRegexp(ValueError, '`mask` should be a list.'):
      concat_layer.compute_mask([i1, i2], x1)
    with self.assertRaisesRegexp(ValueError, '`inputs` should be a list.'):
      concat_layer.compute_mask(i1, [None, None])
    with self.assertRaisesRegexp(ValueError, 'should have the same length'):
      concat_layer.compute_mask([i1, i2], [None])
    with self.assertRaisesRegexp(ValueError,
                                 'layer should be called on a list of inputs'):
      concat_layer(i1)

  def test_merge_dot(self):
    """Checks `dot` with positive and negative axes, plus output-shape math."""
    i1 = keras.layers.Input(shape=(4,))
    i2 = keras.layers.Input(shape=(4,))
    o = keras.layers.dot([i1, i2], axes=1)
    self.assertListEqual(o.shape.as_list(), [None, 1])
    model = keras.models.Model([i1, i2], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
    # Smoke check: serializing the layer config must not raise.
    _ = keras.layers.Dot(axes=1).get_config()

    x1 = np.random.random((2, 4))
    x2 = np.random.random((2, 4))
    out = model.predict([x1, x2])
    self.assertEqual(out.shape, (2, 1))
    # Row-wise dot products, kept as a (batch, 1) column.
    expected = np.zeros((2, 1))
    expected[0, 0] = np.dot(x1[0], x2[0])
    expected[1, 0] = np.dot(x1[1], x2[1])
    self.assertAllClose(out, expected, atol=1e-4)

    # Test with negative tuple of axes.
    o = keras.layers.dot([i1, i2], axes=(-1, -1))
    self.assertListEqual(o.shape.as_list(), [None, 1])
    model = keras.models.Model([i1, i2], o)
    model.run_eagerly = testing_utils.should_run_eagerly()
    model._experimental_run_tf_function = testing_utils.should_run_tf_function()
    out = model.predict([x1, x2])
    self.assertEqual(out.shape, (2, 1))
    self.assertAllClose(out, expected, atol=1e-4)

    # test compute_output_shape
    layer = keras.layers.Dot(axes=-1)
    self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1))

  # `Concatenate` is deliberately not listed: concatenating ragged rows and
  # then padding puts all padding at the end of each row, while concatenating
  # pre-padded dense rows interleaves it, so the two results legitimately
  # differ. Every op listed here maps (0, 0) -> 0, so padding commutes with it.
  @parameterized.named_parameters(
      *tf_test_util.generate_combinations_with_testcase_name(
          layer=[keras.layers.Add, keras.layers.Subtract,
                 keras.layers.Multiply, keras.layers.Minimum,
                 keras.layers.Maximum, keras.layers.Average]))
  def test_merge_with_ragged_input(self, layer):
    """Checks that `layer` gives the same result on ragged and padded input.

    Args:
      layer: the merge-layer class under test; a fresh instance is created
        for both the ragged and the dense model.
    """
    ragged_data = ragged_factory_ops.constant(
        [[1., 1., 1.], [1., 1.], [1., 1., 1., 1.]], ragged_rank=1)
    dense_data = ragged_data.to_tensor()
    input1 = keras.Input(shape=(None,), ragged=True)
    input2 = keras.Input(shape=(None,), ragged=True)
    out = layer()([input1, input2])
    model = keras.models.Model(inputs=[input1, input2], outputs=out)
    out_ragged = model.predict([ragged_data, ragged_data], steps=1)
    out_ragged = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        out_ragged).to_tensor()

    input1 = keras.Input(shape=(None,))
    input2 = keras.Input(shape=(None,))
    out = layer()([input1, input2])
    model = keras.models.Model(inputs=[input1, input2], outputs=out)
    out_dense = model.predict([dense_data, dense_data], steps=1)

    self.assertAllEqual(out_dense, out_ragged)
260
261
@tf_test_util.run_all_in_graph_and_eager_modes
class MergeLayersTestNoExecution(test.TestCase):
  """Merge-layer tests that only build symbolic tensors.

  These tests check output shapes, output masks and argument validation on
  `keras.layers.Input` tensors; no model is built or run.
  """

  def test_merge_elementwise_errors(self):
    """`add` rejects mismatched shapes, a single input, and a non-list."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 6))
    with self.assertRaises(ValueError):
      keras.layers.add([i1, i2])
    with self.assertRaises(ValueError):
      keras.layers.add([i1])
    with self.assertRaises(ValueError):
      keras.layers.add(i1)

  def test_concatenate_errors(self):
    """`concatenate` rejects mismatched shapes, a non-list, and one input."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(3, 5))
    with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'):
      keras.layers.concatenate([i1, i2], axis=-1)
    with self.assertRaisesRegexp(ValueError, 'called on a list'):
      keras.layers.concatenate(i1, axis=-1)
    with self.assertRaisesRegexp(ValueError, 'called on a list'):
      keras.layers.concatenate([i1], axis=-1)

  def test_concatenate_with_partial_shape(self):
    """Unknown (None) dimensions are compatible; known ones must agree."""
    i1 = keras.layers.Input(shape=(5,), batch_size=32)
    i2 = keras.layers.Input(shape=(5,))
    i3 = keras.layers.Input(shape=(4, 5), batch_size=32)
    i4 = keras.layers.Input(shape=(None,), batch_size=64)
    i5 = keras.layers.Input(shape=(7,))

    # Valid: i2's batch size is unknown, so it is compatible with i1's 32.
    keras.layers.concatenate([i1, i2], axis=-1)

    # Different rank
    with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'):
      keras.layers.concatenate([i1, i3], axis=-1)

    # Valid cases with partial dimension information.
    keras.layers.concatenate([i1, i4], axis=0)
    keras.layers.concatenate([i2, i4], axis=0)
    keras.layers.concatenate([i2, i4], axis=1)
    keras.layers.concatenate([i1, i2, i4], axis=0)
    keras.layers.concatenate([i1, i5], axis=1)

    # Mismatch in batch dimension (32 vs 64) when concatenating on axis -1.
    with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'):
      keras.layers.concatenate([i1, i4], axis=-1)

    with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'):
      keras.layers.concatenate([i1, i2, i4], axis=-1)

  def test_dot_errors(self):
    """`dot` / `Dot` reject incompatible shapes and wrong input counts."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 6))
    i3 = keras.layers.Input(shape=(4, 6))
    with self.assertRaises(ValueError):
      keras.layers.dot([i1, i2], axes=-1)
    with self.assertRaises(ValueError):
      keras.layers.dot(i1, axes=-1)
    with self.assertRaises(ValueError):
      keras.layers.dot([i1], axes=-1)
    with self.assertRaises(ValueError):
      keras.layers.dot([i1, i2, i3], axes=-1)
    # Build the layer outside the context manager so that only
    # `compute_output_shape` can satisfy the expected ValueError.
    dot = keras.layers.Dot(1)
    with self.assertRaises(ValueError):
      dot.compute_output_shape(1)

  def test_merge_subtract(self):
    """`subtract` keeps the input shape and rejects bad shapes/input counts."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    y = keras.layers.subtract([i1, i2])
    self.assertEqual(y.shape.as_list(), [None, 4, 5])

    # Test invalid use cases
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(3, 5))
    with self.assertRaises(ValueError):
      keras.layers.subtract([i1, i2])
    with self.assertRaises(ValueError):
      keras.layers.subtract([i1, i1, i1])

  def test_merge_add_masking(self):
    """A `Masking` input yields an output mask without the feature axis."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    m1 = keras.layers.Masking()(i1)
    layer = keras.layers.Add()
    o = layer([m1, i2])
    self.assertListEqual(o.shape.as_list(), [None, 4, 5])
    mask = layer.output_mask
    self.assertListEqual(mask.shape.as_list(), [None, 4])

  def test_merge_add_dynamic_shape(self):
    """An unknown dimension is filled in from the other input's static shape."""
    i1 = keras.Input(batch_shape=(4, None), dtype='float32')
    i2 = keras.Input(batch_shape=(4, 5), dtype='float32')
    layer = keras.layers.Add()
    o = layer([i1, i2])
    self.assertListEqual(o.shape.as_list(), [4, 5])

  def test_merge_concatenate_masking(self):
    """Concatenate with a masked input yields a mask without the feature axis."""
    i1 = keras.layers.Input(shape=(4, 5))
    i2 = keras.layers.Input(shape=(4, 5))
    m1 = keras.layers.Masking()(i1)
    layer = keras.layers.Concatenate()
    o = layer([m1, i2])
    self.assertListEqual(o.shape.as_list(), [None, 4, 10])
    mask = layer.output_mask
    self.assertListEqual(mask.shape.as_list(), [None, 4])
371
372
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner for the
  # test cases defined in this module.
  test.main()
375