1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for merge layers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from absl.testing import parameterized 22import numpy as np 23 24from tensorflow.python import keras 25from tensorflow.python.framework import test_util as tf_test_util 26from tensorflow.python.keras import backend as K 27from tensorflow.python.keras import keras_parameterized 28from tensorflow.python.keras import testing_utils 29from tensorflow.python.ops.ragged import ragged_tensor 30from tensorflow.python.ops.ragged import ragged_factory_ops 31from tensorflow.python.platform import test 32 33 34@keras_parameterized.run_all_keras_modes 35class MergeLayersTest(keras_parameterized.TestCase): 36 37 def test_merge_add(self): 38 i1 = keras.layers.Input(shape=(4, 5)) 39 i2 = keras.layers.Input(shape=(4, 5)) 40 i3 = keras.layers.Input(shape=(4, 5)) 41 42 add_layer = keras.layers.Add() 43 o = add_layer([i1, i2, i3]) 44 self.assertListEqual(o.shape.as_list(), [None, 4, 5]) 45 model = keras.models.Model([i1, i2, i3], o) 46 model.run_eagerly = testing_utils.should_run_eagerly() 47 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 48 49 x1 = np.random.random((2, 4, 5)) 50 x2 = np.random.random((2, 4, 5)) 51 x3 = np.random.random((2, 4, 5)) 52 out = model.predict([x1, x2, x3]) 53 self.assertEqual(out.shape, (2, 4, 5)) 54 self.assertAllClose(out, x1 + x2 + x3, atol=1e-4) 55 56 self.assertEqual( 57 add_layer.compute_mask([i1, i2, i3], [None, None, None]), None) 58 self.assertTrue( 59 np.all( 60 K.eval( 61 add_layer.compute_mask( 62 [i1, i2], [K.variable(x1), K.variable(x2)])))) 63 64 with self.assertRaisesRegexp(ValueError, '`mask` should be a list.'): 65 add_layer.compute_mask([i1, i2, i3], x1) 66 with self.assertRaisesRegexp(ValueError, '`inputs` should be a list.'): 67 add_layer.compute_mask(i1, [None, None, None]) 68 with self.assertRaisesRegexp(ValueError, ' should have the same length.'): 69 add_layer.compute_mask([i1, i2, i3], [None, None]) 70 71 def test_merge_subtract(self): 72 i1 = keras.layers.Input(shape=(4, 5)) 73 i2 = keras.layers.Input(shape=(4, 5)) 74 i3 = keras.layers.Input(shape=(4, 5)) 75 76 subtract_layer = keras.layers.Subtract() 77 o = subtract_layer([i1, i2]) 78 self.assertListEqual(o.shape.as_list(), [None, 4, 5]) 79 model = keras.models.Model([i1, i2], o) 80 model.run_eagerly = testing_utils.should_run_eagerly() 81 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 82 83 x1 = np.random.random((2, 4, 5)) 84 x2 = np.random.random((2, 4, 5)) 85 out = model.predict([x1, x2]) 86 self.assertEqual(out.shape, (2, 4, 5)) 87 self.assertAllClose(out, x1 - x2, atol=1e-4) 88 89 self.assertEqual(subtract_layer.compute_mask([i1, i2], [None, None]), None) 90 self.assertTrue( 91 np.all( 92 K.eval( 93 subtract_layer.compute_mask( 94 [i1, i2], [K.variable(x1), K.variable(x2)])))) 95 96 with self.assertRaisesRegexp(ValueError, '`mask` should be a list.'): 97 subtract_layer.compute_mask([i1, i2], x1) 98 with self.assertRaisesRegexp(ValueError, '`inputs` should be a list.'): 99 subtract_layer.compute_mask(i1, [None, None]) 100 with self.assertRaisesRegexp(ValueError, 101 'layer should be called on exactly 2 inputs'): 102 subtract_layer([i1, i2, i3]) 103 with self.assertRaisesRegexp(ValueError, 104 'layer should be called on exactly 2 inputs'): 105 subtract_layer([i1]) 106 107 def test_merge_multiply(self): 108 i1 = keras.layers.Input(shape=(4, 5)) 109 i2 = keras.layers.Input(shape=(4, 5)) 110 i3 = keras.layers.Input(shape=(4, 5)) 111 o = keras.layers.multiply([i1, i2, i3]) 112 self.assertListEqual(o.shape.as_list(), [None, 4, 5]) 113 model = keras.models.Model([i1, i2, i3], o) 114 model.run_eagerly = testing_utils.should_run_eagerly() 115 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 116 117 x1 = np.random.random((2, 4, 5)) 118 x2 = np.random.random((2, 4, 5)) 119 x3 = np.random.random((2, 4, 5)) 120 out = model.predict([x1, x2, x3]) 121 self.assertEqual(out.shape, (2, 4, 5)) 122 self.assertAllClose(out, x1 * x2 * x3, atol=1e-4) 123 124 def test_merge_average(self): 125 i1 = keras.layers.Input(shape=(4, 5)) 126 i2 = keras.layers.Input(shape=(4, 5)) 127 o = keras.layers.average([i1, i2]) 128 self.assertListEqual(o.shape.as_list(), [None, 4, 5]) 129 model = keras.models.Model([i1, i2], o) 130 model.run_eagerly = testing_utils.should_run_eagerly() 131 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 132 133 x1 = np.random.random((2, 4, 5)) 134 x2 = np.random.random((2, 4, 5)) 135 out = model.predict([x1, x2]) 136 self.assertEqual(out.shape, (2, 4, 5)) 137 self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4) 138 139 def test_merge_maximum(self): 140 i1 = keras.layers.Input(shape=(4, 5)) 141 i2 = keras.layers.Input(shape=(4, 5)) 142 o = keras.layers.maximum([i1, i2]) 143 self.assertListEqual(o.shape.as_list(), [None, 4, 5]) 144 model = keras.models.Model([i1, i2], o) 145 model.run_eagerly = testing_utils.should_run_eagerly() 146 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 147 148 x1 = np.random.random((2, 4, 5)) 149 x2 = np.random.random((2, 4, 5)) 150 out = model.predict([x1, x2]) 151 self.assertEqual(out.shape, (2, 4, 5)) 152 self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) 153 154 def test_merge_minimum(self): 155 i1 = keras.layers.Input(shape=(4, 5)) 156 i2 = keras.layers.Input(shape=(4, 5)) 157 o = keras.layers.minimum([i1, i2]) 158 self.assertListEqual(o.shape.as_list(), [None, 4, 5]) 159 model = keras.models.Model([i1, i2], o) 160 model.run_eagerly = testing_utils.should_run_eagerly() 161 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 162 163 x1 = np.random.random((2, 4, 5)) 164 x2 = np.random.random((2, 4, 5)) 165 out = model.predict([x1, x2]) 166 self.assertEqual(out.shape, (2, 4, 5)) 167 self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) 168 169 def test_merge_concatenate(self): 170 i1 = keras.layers.Input(shape=(4, 5)) 171 i2 = keras.layers.Input(shape=(4, 5)) 172 concat_layer = keras.layers.Concatenate(axis=1) 173 o = concat_layer([i1, i2]) 174 self.assertListEqual(o.shape.as_list(), [None, 8, 5]) 175 model = keras.models.Model([i1, i2], o) 176 model.run_eagerly = testing_utils.should_run_eagerly() 177 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 178 179 x1 = np.random.random((2, 4, 5)) 180 x2 = np.random.random((2, 4, 5)) 181 out = model.predict([x1, x2]) 182 self.assertEqual(out.shape, (2, 8, 5)) 183 self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4) 184 185 self.assertEqual(concat_layer.compute_mask([i1, i2], [None, None]), None) 186 self.assertTrue( 187 np.all( 188 K.eval( 189 concat_layer.compute_mask( 190 [i1, i2], [K.variable(x1), K.variable(x2)])))) 191 192 with self.assertRaisesRegexp(ValueError, '`mask` should be a list.'): 193 concat_layer.compute_mask([i1, i2], x1) 194 with self.assertRaisesRegexp(ValueError, '`inputs` should be a list.'): 195 concat_layer.compute_mask(i1, [None, None]) 196 with self.assertRaisesRegexp(ValueError, 'should have the same length'): 197 concat_layer.compute_mask([i1, i2], [None]) 198 with self.assertRaisesRegexp(ValueError, 199 'layer should be called on a list of inputs'): 200 concat_layer(i1) 201 202 def test_merge_dot(self): 203 i1 = keras.layers.Input(shape=(4,)) 204 i2 = keras.layers.Input(shape=(4,)) 205 o = keras.layers.dot([i1, i2], axes=1) 206 self.assertListEqual(o.shape.as_list(), [None, 1]) 207 model = keras.models.Model([i1, i2], o) 208 model.run_eagerly = testing_utils.should_run_eagerly() 209 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 210 _ = keras.layers.Dot(axes=1).get_config() 211 212 x1 = np.random.random((2, 4)) 213 x2 = np.random.random((2, 4)) 214 out = model.predict([x1, x2]) 215 self.assertEqual(out.shape, (2, 1)) 216 expected = np.zeros((2, 1)) 217 expected[0, 0] = np.dot(x1[0], x2[0]) 218 expected[1, 0] = np.dot(x1[1], x2[1]) 219 self.assertAllClose(out, expected, atol=1e-4) 220 221 # Test with negative tuple of axes. 222 o = keras.layers.dot([i1, i2], axes=(-1, -1)) 223 self.assertListEqual(o.shape.as_list(), [None, 1]) 224 model = keras.models.Model([i1, i2], o) 225 model.run_eagerly = testing_utils.should_run_eagerly() 226 model._experimental_run_tf_function = testing_utils.should_run_tf_function() 227 out = model.predict([x1, x2]) 228 self.assertEqual(out.shape, (2, 1)) 229 self.assertAllClose(out, expected, atol=1e-4) 230 231 # test compute_output_shape 232 layer = keras.layers.Dot(axes=-1) 233 self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1)) 234 235 @parameterized.named_parameters( 236 *tf_test_util.generate_combinations_with_testcase_name( 237 layer=[keras.layers.Add, keras.layers.Subtract, 238 keras.layers.Multiply, keras.layers.Minimum, 239 keras.layers.Maximum, keras.layers.Average, 240 keras.layers.Concatenate])) 241 def test_merge_with_ragged_input(self, layer): 242 ragged_data = ragged_factory_ops.constant( 243 [[1., 1., 1.], [1., 1.], [1., 1., 1., 1.]], ragged_rank=1) 244 dense_data = ragged_data.to_tensor() 245 input1 = keras.Input(shape=(None,), ragged=True) 246 input2 = keras.Input(shape=(None,), ragged=True) 247 out = keras.layers.Add()([input1, input2]) 248 model = keras.models.Model(inputs=[input1, input2], outputs=out) 249 out_ragged = model.predict([ragged_data, ragged_data], steps=1) 250 out_ragged = ragged_tensor.convert_to_tensor_or_ragged_tensor( 251 out_ragged).to_tensor() 252 253 input1 = keras.Input(shape=(None,)) 254 input2 = keras.Input(shape=(None,)) 255 out = keras.layers.Add()([input1, input2]) 256 model = keras.models.Model(inputs=[input1, input2], outputs=out) 257 out_dense = model.predict([dense_data, dense_data], steps=1) 258 259 self.assertAllEqual(out_dense, out_ragged) 260 261 262@tf_test_util.run_all_in_graph_and_eager_modes 263class MergeLayersTestNoExecution(test.TestCase): 264 265 def test_merge_elementwise_errors(self): 266 i1 = keras.layers.Input(shape=(4, 5)) 267 i2 = keras.layers.Input(shape=(4, 6)) 268 with self.assertRaises(ValueError): 269 keras.layers.add([i1, i2]) 270 with self.assertRaises(ValueError): 271 keras.layers.add([i1]) 272 with self.assertRaises(ValueError): 273 keras.layers.add(i1) 274 with self.assertRaises(ValueError): 275 keras.layers.add([i1]) 276 277 def test_concatenate_errors(self): 278 i1 = keras.layers.Input(shape=(4, 5)) 279 i2 = keras.layers.Input(shape=(3, 5)) 280 with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'): 281 keras.layers.concatenate([i1, i2], axis=-1) 282 with self.assertRaisesRegexp(ValueError, 'called on a list'): 283 keras.layers.concatenate(i1, axis=-1) 284 with self.assertRaisesRegexp(ValueError, 'called on a list'): 285 keras.layers.concatenate([i1], axis=-1) 286 287 def test_concatenate_with_partial_shape(self): 288 i1 = keras.layers.Input(shape=(5,), batch_size=32) 289 i2 = keras.layers.Input(shape=(5,)) 290 i3 = keras.layers.Input(shape=(4, 5), batch_size=32) 291 i4 = keras.layers.Input(shape=(None,), batch_size=64) 292 i5 = keras.layers.Input(shape=(7,)) 293 294 # Valid case since the i2 has a dynamic batch size. 295 keras.layers.concatenate([i1, i2], axis=-1) 296 297 # Different rank 298 with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'): 299 keras.layers.concatenate([i1, i3], axis=-1) 300 301 # Valid case with partial dimension information 302 keras.layers.concatenate([i1, i4], axis=0) 303 keras.layers.concatenate([i2, i4], axis=0) 304 keras.layers.concatenate([i2, i4], axis=1) 305 keras.layers.concatenate([i1, i2, i4], axis=0) 306 keras.layers.concatenate([i1, i5], axis=1) 307 308 # Mismatch in batch dimension. 309 with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'): 310 keras.layers.concatenate([i1, i4], axis=-1) 311 312 with self.assertRaisesRegexp(ValueError, 'inputs with matching shapes'): 313 keras.layers.concatenate([i1, i2, i4], axis=-1) 314 315 def test_dot_errors(self): 316 i1 = keras.layers.Input(shape=(4, 5)) 317 i2 = keras.layers.Input(shape=(4, 6)) 318 i3 = keras.layers.Input(shape=(4, 6)) 319 with self.assertRaises(ValueError): 320 keras.layers.dot([i1, i2], axes=-1) 321 with self.assertRaises(ValueError): 322 keras.layers.dot(i1, axes=-1) 323 with self.assertRaises(ValueError): 324 keras.layers.dot([i1], axes=-1) 325 with self.assertRaises(ValueError): 326 keras.layers.dot([i1, i2, i3], axes=-1) 327 with self.assertRaises(ValueError): 328 dot = keras.layers.Dot(1) 329 dot.compute_output_shape(1) 330 331 def test_merge_subtract(self): 332 i1 = keras.layers.Input(shape=(4, 5)) 333 i2 = keras.layers.Input(shape=(4, 5)) 334 y = keras.layers.subtract([i1, i2]) 335 self.assertEqual(y.shape.as_list(), [None, 4, 5]) 336 337 # Test invalid use cases 338 i1 = keras.layers.Input(shape=(4, 5)) 339 i2 = keras.layers.Input(shape=(3, 5)) 340 with self.assertRaises(ValueError): 341 keras.layers.subtract([i1, i2]) 342 with self.assertRaises(ValueError): 343 keras.layers.subtract([i1, i1, i1]) 344 345 def test_merge_add_masking(self): 346 i1 = keras.layers.Input(shape=(4, 5)) 347 i2 = keras.layers.Input(shape=(4, 5)) 348 m1 = keras.layers.Masking()(i1) 349 layer = keras.layers.Add() 350 o = layer([m1, i2]) 351 self.assertListEqual(o.shape.as_list(), [None, 4, 5]) 352 mask = layer.output_mask 353 self.assertListEqual(mask.shape.as_list(), [None, 4]) 354 355 def test_merge_add_dynamic_shape(self): 356 i1 = keras.Input(batch_shape=(4, None), dtype='float32') 357 i2 = keras.Input(batch_shape=(4, 5), dtype='float32') 358 layer = keras.layers.Add() 359 o = layer([i1, i2]) 360 self.assertListEqual(o.shape.as_list(), [4, 5]) 361 362 def test_merge_concatenate_masking(self): 363 i1 = keras.layers.Input(shape=(4, 5)) 364 i2 = keras.layers.Input(shape=(4, 5)) 365 m1 = keras.layers.Masking()(i1) 366 layer = keras.layers.Concatenate() 367 o = layer([m1, i2]) 368 self.assertListEqual(o.shape.as_list(), [None, 4, 10]) 369 mask = layer.output_mask 370 self.assertListEqual(mask.shape.as_list(), [None, 4]) 371 372 373if __name__ == '__main__': 374 test.main() 375