# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests specific to `Sequential` model."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class TestSequential(keras_parameterized.TestCase):
  """Tests for behavior that is specific to `Sequential`.

  Most Sequential model API tests are covered in `training_test.py`.
  """

  @keras_parameterized.run_all_keras_modes
  def test_basic_methods(self):
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(1, input_dim=2))
    model.add(keras.layers.Dropout(0.3, name='dp'))
    model.add(keras.layers.Dense(2, kernel_regularizer='l2',
                                 kernel_constraint='max_norm'))
    self.assertLen(model.layers, 3)
    # Two Dense layers with a kernel and a bias each; Dropout has no weights.
    self.assertLen(model.weights, 2 * 2)
    self.assertEqual(model.get_layer(name='dp').name, 'dp')

  @keras_parameterized.run_all_keras_modes
  def test_input_defined_first_layer(self):
    model = keras.models.Sequential()
    model.add(keras.Input(shape=(2,), name='input_layer'))
    model.add(keras.layers.Dense(1))
    model.add(keras.layers.Dropout(0.3, name='dp'))
    model.add(keras.layers.Dense(2, kernel_regularizer='l2',
                                 kernel_constraint='max_norm'))
    # The leading `Input` is not reported as one of the model's layers.
    self.assertLen(model.layers, 3)
    self.assertLen(model.weights, 2 * 2)
    self.assertEqual(model.get_layer(name='dp').name, 'dp')

  @keras_parameterized.run_all_keras_modes
  def test_single_layer_in_init(self):
    # A single layer (not wrapped in a list) is accepted by the constructor.
    model = keras.models.Sequential(keras.layers.Dense(1))
    self.assertLen(model.layers, 1)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_pop(self):
    num_hidden = 5
    input_dim = 3
    batch_size = 5
    num_classes = 2

    model = testing_utils.get_small_sequential_mlp(
        num_hidden, num_classes, input_dim)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())
    x = np.random.random((batch_size, input_dim))
    y = np.random.random((batch_size, num_classes))
    model.fit(x, y, epochs=1)
    model.pop()
    self.assertLen(model.layers, 1)
    # After popping, the model's output is the hidden layer's output.
    self.assertEqual(model.output_shape, (None, num_hidden))
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())
    y = np.random.random((batch_size, num_hidden))
    model.fit(x, y, epochs=1)

    # Test popping single-layer model
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
    model.pop()
    self.assertEmpty(model.layers)
    self.assertIsNone(model.outputs)

    # Invalid use case: popping from an empty model.
    model = keras.models.Sequential()
    with self.assertRaises(TypeError):
      model.pop()

  @keras_parameterized.run_all_keras_modes
  def test_sequential_deferred_build_with_np_arrays(self):
    num_hidden = 5
    input_dim = 3
    batch_size = 5
    num_classes = 2

    model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    self.assertLen(model.layers, 2)
    # No input shape was given, so weights don't exist until the first fit.
    with self.assertRaisesRegex(
        ValueError, 'Weights for model .* have not yet been created'):
      len(model.weights)
    self.assertFalse(model.built)

    x = np.random.random((batch_size, input_dim))
    y = np.random.random((batch_size, num_classes))
    model.fit(x, y, epochs=1)
    self.assertTrue(model.built)
    self.assertLen(model.weights, 2 * 2)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_deferred_build_with_dataset_iterators(self):
    num_hidden = 5
    input_dim = 3
    num_classes = 2
    num_samples = 50
    steps_per_epoch = 10

    model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    self.assertLen(model.layers, 2)
    # No input shape was given, so weights don't exist until the first fit.
    with self.assertRaisesRegex(
        ValueError, 'Weights for model .* have not yet been created'):
      len(model.weights)
    self.assertFalse(model.built)

    x = array_ops.ones((num_samples, input_dim))
    y = array_ops.zeros((num_samples, num_classes))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)

    model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch)
    self.assertTrue(model.built)
    self.assertLen(model.weights, 2 * 2)

  # TODO(kaftan) This test fails w/ run_with_all_keras_modes. File ticket
  @parameterized.parameters((True,), (False,))
  def test_training_and_eval_methods_on_symbolic_tensors(self, deferred):
    with ops.Graph().as_default(), self.cached_session():

      def get_model():
        """Returns a freshly compiled MLP, built lazily iff `deferred`."""
        if deferred:
          model = testing_utils.get_small_sequential_mlp(10, 4)
        else:
          model = testing_utils.get_small_sequential_mlp(10, 4, input_dim=3)
        model.compile(
            optimizer='rmsprop',
            loss='categorical_crossentropy',
            metrics=['accuracy'])
        return model

      inputs = keras.backend.zeros(shape=(10, 3))
      targets = keras.backend.zeros(shape=(10, 4))

      model = get_model()
      model.fit(inputs, targets, epochs=10, steps_per_epoch=30)

      model = get_model()
      model.evaluate(inputs, targets, steps=2, verbose=0)

      model = get_model()
      model.predict(inputs, steps=2)

      model = get_model()
      model.train_on_batch(inputs, targets)

      model = get_model()
      model.test_on_batch(inputs, targets)

      model = get_model()
      model.fit(
          inputs,
          targets,
          epochs=1,
          steps_per_epoch=2,
          verbose=0,
          validation_data=(inputs, targets),
          validation_steps=2)

  @keras_parameterized.run_all_keras_modes
  def test_invalid_use_cases(self):
    # Added objects must be layer instances. The model is created outside the
    # assertion so only `add` can satisfy the expected TypeError.
    model = keras.models.Sequential()
    with self.assertRaises(TypeError):
      model.add(None)

  @keras_parameterized.run_all_keras_modes
  def test_nested_sequential_trainability(self):
    input_dim = 20
    num_units = 10
    num_classes = 2

    inner_model = keras.models.Sequential()
    inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,)))

    model = keras.models.Sequential()
    model.add(inner_model)
    model.add(keras.layers.Dense(num_classes))

    self.assertLen(model.layers, 2)

    # Freezing the nested model must hide its weights from the outer model.
    self.assertLen(model.trainable_weights, 4)
    inner_model.trainable = False
    self.assertLen(model.trainable_weights, 2)
    inner_model.trainable = True
    self.assertLen(model.trainable_weights, 4)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_update_disabling(self):
    val_a = np.random.random((10, 4))
    val_out = np.random.random((10, 4))

    model = keras.models.Sequential()
    model.add(keras.layers.BatchNormalization(input_shape=(4,)))

    # Training a frozen model must leave its predictions unchanged.
    model.trainable = False
    model.compile(
        'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())

    x1 = model.predict(val_a)
    model.train_on_batch(val_a, val_out)
    x2 = model.predict(val_a)
    self.assertAllClose(x1, x2, atol=1e-7)

    # Once unfrozen (and recompiled), training must change the predictions.
    model.trainable = True
    model.compile(
        'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())

    model.train_on_batch(val_a, val_out)
    x2 = model.predict(val_a)
    # Sum absolute differences so opposite-signed changes cannot cancel out.
    self.assertGreater(np.sum(np.abs(x1 - x2)), 1e-5)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_deferred_build_serialization(self):
    num_hidden = 5
    input_dim = 3
    batch_size = 5
    num_classes = 2

    model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    self.assertFalse(model.built)

    x = np.random.random((batch_size, input_dim))
    y = np.random.random((batch_size, num_classes))
    model.train_on_batch(x, y)
    self.assertTrue(model.built)

    # Round-trip the config of a model that was built lazily.
    config = model.get_config()
    new_model = keras.models.Sequential.from_config(config)
    new_model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    x = np.random.random((batch_size, input_dim))
    y = np.random.random((batch_size, num_classes))
    new_model.train_on_batch(x, y)
    self.assertLen(new_model.layers, 2)
    self.assertLen(new_model.weights, 4)

  @keras_parameterized.run_all_keras_modes
  def test_sequential_shape_inference_deferred(self):
    model = testing_utils.get_small_sequential_mlp(4, 5)
    output_shape = model.compute_output_shape((None, 7))
    self.assertEqual(tuple(output_shape.as_list()), (None, 5))

  @keras_parameterized.run_all_keras_modes
  def test_sequential_build_deferred(self):
    model = testing_utils.get_small_sequential_mlp(4, 5)

    model.build((None, 10))
    self.assertTrue(model.built)
    self.assertLen(model.weights, 4)

    # Test with nested model
    model = testing_utils.get_small_sequential_mlp(4, 3)
    inner_model = testing_utils.get_small_sequential_mlp(4, 5)
    model.add(inner_model)

    model.build((None, 10))
    self.assertTrue(model.built)
    self.assertLen(model.weights, 8)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_sequential_deferred_manual_build(self):
    model = testing_utils.get_small_sequential_mlp(4, 5)
    self.assertFalse(model.built)
    # Calling the model on a tensor builds it.
    model(array_ops.zeros([1, 2]))
    self.assertTrue(model.built)
    model.compile(
        'rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())
    model.train_on_batch(np.zeros((1, 2)), np.zeros((1, 5)))

  @keras_parameterized.run_all_keras_modes
  def test_sequential_nesting(self):
    model = testing_utils.get_small_sequential_mlp(4, 3)
    inner_model = testing_utils.get_small_sequential_mlp(4, 5)
    model.add(inner_model)

    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())
    x = np.random.random((2, 6))
    y = np.random.random((2, 5))
    model.fit(x, y, epochs=1)

  @test_util.run_v1_only('Behavior changed in V2.')
  def test_variable_names_deferred(self):
    model = keras.models.Sequential([keras.layers.Dense(3)])
    model.add(keras.layers.Dense(2))
    model(array_ops.ones([2, 4]))
    # In this deferred-build case, the layers' weights are created when the
    # model is first called and carry the model name ('sequential/') as a
    # prefix, as asserted below. Regular (non-deferred) sequential models,
    # which wrap a graph network, build their weights without that prefix
    # because the Functional API __call__ resets the name scope. That is
    # fixable, but changing it would be backwards incompatible.
    self.assertEqual(
        ['sequential/dense/kernel:0', 'sequential/dense/bias:0',
         'sequential/dense_1/kernel:0', 'sequential/dense_1/bias:0'],
        [v.name for v in model.variables])

  @keras_parameterized.run_all_keras_modes
  def test_input_assumptions_propagation(self):
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(1))
    if context.executing_eagerly():
      # The Dense layer's input-rank requirement is enforced via the model.
      with self.assertRaisesRegex(ValueError,
                                  'expected min_ndim=2, found ndim=0'):
        model(1.0)

  @keras_parameterized.run_all_keras_modes
  def test_string_input(self):
    seq = keras.Sequential([
        keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
        keras.layers.Lambda(lambda x: x[0])
    ])
    seq.run_eagerly = testing_utils.should_run_eagerly()
    preds = seq.predict([['tensorflow eager']])
    self.assertEqual(preds.shape, (1,))

  @keras_parameterized.run_all_keras_modes
  def test_multi_output_layer_not_accepted(self):

    class MultiOutputLayer(keras.layers.Layer):
      """Layer returning two tensors, which Sequential must reject."""

      def call(self, inputs):
        return inputs, inputs

    with self.assertRaisesRegex(ValueError,
                                'should have a single output tensor'):
      keras.Sequential([MultiOutputLayer(input_shape=(3,))])

    with self.assertRaisesRegex(ValueError,
                                'should have a single output tensor'):
      keras.Sequential([
          keras.layers.Dense(1, input_shape=(3,)),
          MultiOutputLayer()])

    # Should also raise error in a deferred build mode
    with self.assertRaisesRegex(ValueError,
                                'should have a single output tensor'):
      keras.Sequential([MultiOutputLayer()])(np.zeros((10, 10)))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_layer_add_after_compile_deferred(self):
    model = keras.Sequential([keras.layers.Dense(3)])
    self.assertFalse(model.built)

    model.compile('adam', loss='mse')
    model.fit(np.random.random((1, 3)), np.random.random((1, 3)))
    self.assertTrue(model.built)

    # Adding a layer after the first fit must still allow recompiling/fitting.
    model.add(keras.layers.Dense(3))

    model.compile('adam', loss='mse')
    model.fit(np.random.random((1, 3)), np.random.random((1, 3)))
    self.assertTrue(model.built)

  def test_sequential_layer_tracking(self):
    """Test that Sequential only tracks layers added in init or `.add`."""

    def last_tracked_layer(model):
      """Returns the last direct child layer tracked by `model`."""
      return list(
          model._flatten_layers(include_self=False, recursive=False))[-1]

    layer = keras.layers.Dense(1)
    model = keras.Sequential([layer])
    self.assertIs(last_tracked_layer(model), layer)

    model.a = [keras.layers.Dense(3)]  # should not be added to the layers list.
    self.assertIs(last_tracked_layer(model), layer)

    layer2 = keras.layers.Dense(2)
    model.add(layer2)
    self.assertIs(last_tracked_layer(model), layer2)

    model.a = [keras.layers.Dense(3)]  # should not be added to the layers list.
    self.assertIs(last_tracked_layer(model), layer2)

    model.pop()
    self.assertIs(last_tracked_layer(model), layer)

  def test_config_preserves_input_layer(self):
    model = keras.Sequential([
        keras.Input((None,), name='my_embedding_input', dtype='int32'),
        keras.layers.Embedding(32, 32),
        keras.layers.Dense(3),
    ])
    config = model.get_config()
    new_model = keras.Sequential.from_config(config)
    self.assertTrue(new_model.built)
    # The explicit Input's dtype and name must survive the config round trip.
    layers = list(
        new_model._flatten_layers(include_self=False, recursive=False))
    self.assertEqual(layers[0].dtype, 'int32')
    self.assertEqual(layers[0].name, 'my_embedding_input')

  def test_name_unicity(self):
    model = keras.Sequential()
    model.add(keras.layers.Dense(3, name='specific_name'))
    with self.assertRaisesRegex(ValueError, 'should have unique names'):
      model.add(keras.layers.Dense(3, name='specific_name'))

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_tf_module_call(self):

    class MyModule(module.Module):
      """Plain `tf.Module` defining `__call__`, used as a Sequential step."""

      def __init__(self):
        self.v = variables.Variable(2.)

      def __call__(self, x):
        return self.v * x

    model = keras.Sequential()
    model.add(MyModule())
    model.compile('sgd', 'mse')
    x, y = np.ones((10, 1)), np.ones((10, 1))
    model.fit(x, y, batch_size=2)
    # The module's variable is tracked as a trainable variable of the model.
    self.assertLen(model.trainable_variables, 1)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_tf_module_training(self):

    class MyModule(module.Module):
      """Plain `tf.Module` defining `call`, which receives `training`."""

      def __init__(self):
        self.v = variables.Variable(2.)

      def call(self, x, training=None):
        # training should be set by Sequential.
        assert training is not None
        return self.v * x

    model = keras.Sequential()
    model.add(MyModule())
    model.compile('sgd', 'mse')
    x, y = np.ones((10, 1)), np.ones((10, 1))
    model.fit(x, y, batch_size=2)
    self.assertLen(model.trainable_variables, 1)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_tf_module_error(self):

    class MyModule(module.Module):
      """`tf.Module` with neither `__call__` nor `call`; must be rejected."""

      def __init__(self):
        self.v = variables.Variable(2.)

    model = keras.Sequential()
    with self.assertRaisesRegex(ValueError, 'is not defined'):
      model.add(MyModule())


class TestSequentialEagerIntegration(keras_parameterized.TestCase):
  """Tests Sequential interaction with `tf.function` and explicit `build`."""

  @keras_parameterized.run_all_keras_modes
  def test_defun_on_call(self):
    # Check that one can subclass Sequential and place the `call` in a `defun`.

    class MySequential(keras.Sequential):

      def __init__(self, name=None):
        super(MySequential, self).__init__(name=name)
        self.call = def_function.function(self.call)

    model = MySequential()
    model.add(keras.layers.Dense(4, activation='relu'))
    model.add(keras.layers.Dense(5, activation='softmax'))

    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())

    x = np.random.random((2, 6))
    y = np.random.random((2, 5))
    model.fit(x, y, epochs=1)

  @keras_parameterized.run_all_keras_modes
  def test_build_before_fit(self):
    # Fix for b/112433577
    model = testing_utils.get_small_sequential_mlp(4, 5)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())

    model.build((None, 6))

    x = np.random.random((2, 6))
    y = np.random.random((2, 5))
    model.fit(x, y, epochs=1)

  @keras_parameterized.run_all_keras_modes
  def test_build_empty_network(self):
    x = np.random.random((2, 6))
    y = np.random.random((2, 5))
    model = keras.Sequential()

    # Make sure an empty sequential model can still work with build().
    model.build((None, 6))
    self.assertTrue(model.built)

    model.add(keras.layers.Dense(5, input_shape=(6,)))

    model.compile(
        loss='mse',
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(x, y)

    # Popping the only layer resets the model to unbuilt; it can be rebuilt.
    model.pop()
    self.assertFalse(model.built)

    model.build((None, 6))
    self.assertTrue(model.built)


if __name__ == '__main__':
  test.main()