# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training routines."""

import io
import sys

import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging


class BatchCounterCallback(callbacks.Callback):
  """Callback that counts `on_batch_begin` and `on_batch_end` invocations."""

  def __init__(self):
    # Let the base `Callback` perform its own initialization before the
    # counters are added.
    super(BatchCounterCallback, self).__init__()
    self.batch_begin_count = 0
    self.batch_end_count = 0

  def on_batch_begin(self, *args, **kwargs):
    self.batch_begin_count += 1

  def on_batch_end(self, *args, **kwargs):
    self.batch_end_count += 1


class TestTrainingWithDataset(keras_parameterized.TestCase):
  """Tests `fit`/`evaluate`/`predict` when the input is a dataset."""

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_calling_model_on_same_dataset(self):
    """Calls `fit` twice with the same dataset for training and validation."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    optimizer = 'rmsprop'
    loss = 'mse'
    metrics = ['mae']
    model.compile(
        optimizer,
        loss,
        metrics=metrics,
        run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((10, 3), np.float32)
    targets = np.zeros((10, 4), np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)

    # Call fit with validation data
    model.fit(
        dataset,
        epochs=1,
        steps_per_epoch=2,
        verbose=0,
        validation_data=dataset,
        validation_steps=2)
    model.fit(
        dataset,
        epochs=1,
        steps_per_epoch=2,
        verbose=0,
        validation_data=dataset,
        validation_steps=2)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_training_and_eval_methods_on_dataset(self):
    """Runs fit/evaluate/predict on an infinite dataset and checks errors."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    optimizer = 'rmsprop'
    loss = 'mse'
    metrics = ['mae', metrics_module.CategoricalAccuracy()]
    model.compile(
        optimizer,
        loss,
        metrics=metrics,
        run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((10, 3), np.float32)
    targets = np.zeros((10, 4), np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat()  # Infinite dataset.
    dataset = dataset.batch(10)

    model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
    model.evaluate(dataset, steps=2, verbose=1)
    model.predict(dataset, steps=2)

    # Test with validation data
    model.fit(
        dataset,
        epochs=1,
        steps_per_epoch=2,
        verbose=0,
        validation_data=dataset,
        validation_steps=2)

    # Test with validation split
    with self.assertRaises(ValueError):
      model.fit(
          dataset,
          epochs=1,
          steps_per_epoch=2,
          verbose=0,
          validation_split=0.5,
          validation_steps=2)

    # Test with sample weight.
    sample_weight = np.random.random((10,))
    with self.assertRaisesRegex(
        ValueError, r'`sample_weight` argument is not supported .+dataset'):
      model.fit(
          dataset,
          epochs=1,
          steps_per_epoch=2,
          verbose=0,
          sample_weight=sample_weight)

    with self.assertRaisesRegex(
        ValueError, '(you should not specify a target)|'
        '(`y` argument is not supported when using dataset as input.)'):
      model.fit(dataset, dataset, epochs=1, steps_per_epoch=2, verbose=0)

    # With an infinite dataset, `steps_per_epoch`/`steps` argument is required.
    with self.assertRaises(ValueError):
      model.fit(dataset, epochs=1, verbose=0)
    with self.assertRaises(ValueError):
      model.evaluate(dataset, verbose=0)
    with self.assertRaises(ValueError):
      model.predict(dataset, verbose=0)

  @keras_parameterized.run_with_all_model_types(exclude_models='sequential')
  @keras_parameterized.run_all_keras_modes
  def test_training_and_eval_methods_on_multi_input_output_dataset(self):
    """Feeds a two-input/two-output model from tuple and dict datasets."""
    input_a = keras.layers.Input(shape=(3,), name='input_1')
    input_b = keras.layers.Input(shape=(3,), name='input_2')
    dense = keras.layers.Dense(4, name='dense')
    dropout = keras.layers.Dropout(0.5, name='dropout')
    branch_a = [input_a, dense]
    branch_b = [input_b, dense, dropout]

    model = testing_utils.get_multi_io_model(branch_a, branch_b)
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        run_eagerly=testing_utils.should_run_eagerly())

    input_a_np = np.random.random((10, 3)).astype(dtype=np.float32)
    input_b_np = np.random.random((10, 3)).astype(dtype=np.float32)
    output_d_np = np.random.random((10, 4)).astype(dtype=np.float32)
    output_e_np = np.random.random((10, 4)).astype(dtype=np.float32)

    # Test with tuples
    dataset_tuple = dataset_ops.Dataset.from_tensor_slices(
        ((input_a_np, input_b_np), (output_d_np, output_e_np)))
    dataset_tuple = dataset_tuple.repeat(100)
    dataset_tuple = dataset_tuple.batch(10)

    model.fit(dataset_tuple, epochs=1, steps_per_epoch=2, verbose=1)
    model.evaluate(dataset_tuple, steps=2, verbose=1)

    # Test with dict
    input_dict = {'input_1': input_a_np, 'input_2': input_b_np}
    if testing_utils.get_model_type() == 'subclass':
      output_dict = {'output_1': output_d_np, 'output_2': output_e_np}
    else:
      output_dict = {'dense': output_d_np, 'dropout': output_e_np}

    dataset_dict = dataset_ops.Dataset.from_tensor_slices(
        (input_dict, output_dict))
    dataset_dict = dataset_dict.repeat(100)
    dataset_dict = dataset_dict.batch(10)

    model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
    model.evaluate(dataset_dict, steps=2, verbose=1)

    predict_dataset_dict = dataset_ops.Dataset.from_tensor_slices(input_dict)
    predict_dataset_dict = predict_dataset_dict.repeat(100)
    predict_dataset_dict = predict_dataset_dict.batch(10)
    model.predict(predict_dataset_dict, steps=1)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_dataset_with_sample_weights(self):
    """Runs fit/evaluate/predict on (inputs, targets, weights) elements."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    optimizer = 'rmsprop'
    loss = 'mse'
    metrics = ['mae', metrics_module.CategoricalAccuracy()]
    model.compile(
        optimizer,
        loss,
        metrics=metrics,
        run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((10, 3), np.float32)
    targets = np.zeros((10, 4), np.float32)
    sample_weights = np.ones((10), np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets, sample_weights))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)

    model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
    model.evaluate(dataset, steps=2, verbose=1)
    model.predict(dataset, steps=2)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_dataset_with_sample_weights_correctness(self):
    """Checks the exact weighted loss returned by `evaluate`."""
    x = keras.layers.Input(shape=(1,), name='input')
    y = keras.layers.Dense(
        1, kernel_initializer='ones', bias_initializer='zeros', name='dense')(
            x)
    model = keras.Model(x, y)
    optimizer = 'rmsprop'
    loss = 'mse'
    # Honor the mode selected by `run_all_keras_modes`, like the other tests
    # in this class do.
    model.compile(
        optimizer, loss, run_eagerly=testing_utils.should_run_eagerly())
    inputs = np.array([[0], [1], [2], [3]], np.float32)
    targets = np.array([[2], [4], [6], [8]], np.float32)
    sample_weights = np.array([0.25, 0.5, 0.75, 1], np.float32)
    ds = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets, sample_weights)).batch(2)
    result = model.evaluate(ds, verbose=1)
    # The per sample loss is multiplied by the corresponding sample weight.
    # The average of these weighted losses is the return value of the
    # `evaluate` call. For example, in the test above the average weighted
    # loss is calculated in the following manner:
    # ((2-0)^2 * 0.25 + (4-1)^2 * 0.5 + (6-2)^2 * 0.75 + (8-3)^2 * 1) / 4
    # equals 42.5 / 4 = 10.625
    self.assertEqual(result, 10.625)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_dataset_with_sparse_labels(self):
    """Fits with integer class labels and a sparse crossentropy loss."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    optimizer = 'rmsprop'
    model.compile(
        optimizer,
        loss='sparse_categorical_crossentropy',
        run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((10, 3), dtype=np.float32)
    targets = np.random.randint(0, 4, size=10, dtype=np.int32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)

    model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  @keras_parameterized.run_all_keras_modes
  def test_dataset_fit_correctness(self):
    """Checks reported losses with and without `steps_per_epoch`."""

    class SumLayer(keras.layers.Layer):
      """Sums inputs along axis 1; the weight never changes the output."""

      def build(self, _):
        self.w = self.add_weight('w', ())

      def call(self, inputs):
        # `self.w * 0` ties the weight into the output without altering it.
        return keras.backend.sum(inputs, axis=1, keepdims=True) + self.w * 0

    model = keras.Sequential([SumLayer(input_shape=(2,))])
    model.compile(
        'rmsprop', loss='mae', run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((40, 2), dtype=np.float32)
    inputs[10:20, :] = 2
    inputs[20:30, :] = 1
    inputs[30:, :] = 4
    targets = np.zeros((40, 1), dtype=np.float32)

    # Test correctness with `steps_per_epoch`.
    train_dataset = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets)).batch(10)
    val_dataset = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets)).batch(10)
    history = model.fit(
        train_dataset,
        epochs=2,
        steps_per_epoch=2,
        verbose=1,
        validation_data=val_dataset,
        validation_steps=2)
    self.assertAllClose(history.history['loss'],
                        [inputs[:20].sum() / 20, inputs[20:].sum() / 20])
    # The validation dataset will be reset at the end of each validation run.
    self.assertAllClose(history.history['val_loss'],
                        [inputs[:20].sum() / 20, inputs[:20].sum() / 20])

    # Test correctness with dataset reset.
    train_dataset = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets)).batch(10)
    val_dataset = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets)).batch(10)
    history = model.fit(
        train_dataset, epochs=2, verbose=1, validation_data=val_dataset)
    self.assertAllClose(
        history.history['loss'],
        [inputs.sum() / 40, inputs.sum() / 40])
    self.assertAllClose(
        history.history['val_loss'],
        [inputs.sum() / 40, inputs.sum() / 40])

  def test_dataset_input_shape_validation(self):
    """Checks the errors raised for unbatched or wrongly shaped datasets."""
    with ops.get_default_graph().as_default(), self.cached_session():
      model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
      model.compile(optimizer='rmsprop', loss='mse')

      # User forgets to batch the dataset
      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)

      with self.assertRaisesRegex(
          ValueError,
          r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)'
      ):
        model.train_on_batch(dataset)

      # Wrong input shape
      inputs = np.zeros((10, 5))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      with self.assertRaisesRegex(ValueError,
                                  r'expected (.*?) to have shape \(3,\)'):
        model.train_on_batch(dataset)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_finite_dataset_known_cardinality_no_steps_arg(self):
    """Fits a finite 10-batch dataset for 2 epochs without a steps argument."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.compile(
        'rmsprop', 'mse', run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((100, 3), dtype=np.float32)
    targets = np.random.randint(0, 4, size=100, dtype=np.int32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.batch(10)

    batch_counter = BatchCounterCallback()
    history = model.fit(dataset, epochs=2, verbose=1, callbacks=[batch_counter])

    self.assertLen(history.history['loss'], 2)
    self.assertEqual(batch_counter.batch_end_count, 20)
    model.evaluate(dataset)
    out = model.predict(dataset)
    self.assertEqual(out.shape[0], 100)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_finite_dataset_unknown_cardinality_no_steps_arg(self):
    """Same as the known-cardinality test, but cardinality is UNKNOWN."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.compile(
        'rmsprop', 'mse', run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((100, 3), dtype=np.float32)
    targets = np.random.randint(0, 4, size=100, dtype=np.int32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    # The pass-through filter keeps every element; the assertion below checks
    # that the resulting dataset reports UNKNOWN cardinality.
    dataset = dataset.filter(lambda x, y: True).batch(10)
    self.assertEqual(
        keras.backend.get_value(cardinality.cardinality(dataset)),
        cardinality.UNKNOWN)

    batch_counter = BatchCounterCallback()
    history = model.fit(dataset, epochs=2, verbose=1, callbacks=[batch_counter])

    self.assertLen(history.history['loss'], 2)
    self.assertEqual(batch_counter.batch_end_count, 20)
    model.evaluate(dataset)
    out = model.predict(dataset)
    self.assertEqual(out.shape[0], 100)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_finite_dataset_unknown_cardinality_no_step_with_train_and_val(self):
    """Checks progress output and batch counts when validation is used."""

    class CaptureStdout(object):
      """Context manager that buffers everything written to `sys.stdout`."""

      def __enter__(self):
        self._stdout = sys.stdout
        string_io = io.StringIO()
        sys.stdout = string_io
        self._stringio = string_io
        return self

      def __exit__(self, *args):
        # Restore the real stdout first so it can never stay redirected.
        sys.stdout = self._stdout
        self.output = self._stringio.getvalue()

    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.compile(
        'rmsprop', 'mse', run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((100, 3), dtype=np.float32)
    targets = np.random.randint(0, 4, size=100, dtype=np.int32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.filter(lambda x, y: True).batch(10)
    self.assertEqual(
        keras.backend.get_value(cardinality.cardinality(dataset)),
        cardinality.UNKNOWN)

    batch_counter = BatchCounterCallback()
    with CaptureStdout() as capture:
      history = model.fit(
          dataset,
          epochs=2,
          callbacks=[batch_counter],
          validation_data=dataset.take(3))

    lines = capture.output.splitlines()

    self.assertIn('10/10', lines[-1])

    self.assertLen(history.history['loss'], 2)
    # NOTE(review): one more `on_batch_begin` than `on_batch_end` is expected;
    # presumably the extra begin comes from the step that discovers the
    # dataset is exhausted -- confirm against the training loop.
    self.assertEqual(batch_counter.batch_begin_count, 21)
    self.assertEqual(batch_counter.batch_end_count, 20)
    model.evaluate(dataset)
    out = model.predict(dataset)
    self.assertEqual(out.shape[0], 100)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_finite_dataset_unknown_cardinality_out_of_data(self):
    """Checks the warning logged when `steps_per_epoch` exceeds the data."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.compile(
        'rmsprop', 'mse', run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((100, 3), dtype=np.float32)
    targets = np.random.randint(0, 4, size=100, dtype=np.int32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.filter(lambda x, y: True).batch(10)
    self.assertEqual(
        keras.backend.get_value(cardinality.cardinality(dataset)),
        cardinality.UNKNOWN)

    batch_counter = BatchCounterCallback()
    with test.mock.patch.object(logging, 'warning') as mock_log:
      # steps_per_epoch (200) is greater than the dataset size (100). As this is
      # unexpected, training will stop and not make it to the second epoch.
      history = model.fit(
          dataset,
          epochs=2,
          verbose=1,
          callbacks=[batch_counter],
          steps_per_epoch=200)
      self.assertIn('ran out of data; interrupting training.',
                    str(mock_log.call_args))
      self.assertIn(
          'can generate at least '
          '`steps_per_epoch * epochs` batches (in this case, 400 batches). '
          'You may need to use the repeat() function when '
          'building your dataset.', str(mock_log.call_args))

    self.assertLen(history.history['loss'], 1)
    self.assertEqual(batch_counter.batch_end_count, 10)
    model.evaluate(dataset)
    out = model.predict(dataset)
    self.assertEqual(out.shape[0], 100)

  @keras_parameterized.run_all_keras_modes
  def test_with_external_loss(self):
    """Fits a model whose only loss was added through `add_loss`."""
    inp = keras.Input(shape=(4,), name='inp1')
    out = keras.layers.Dense(2)(inp)
    model = keras.Model(inp, out)
    model.add_loss(math_ops.reduce_mean(out))
    model.compile('rmsprop')
    x = np.ones((10, 4))

    # dataset contains only features, no labels.
    dataset = dataset_ops.Dataset.from_tensor_slices(x).repeat(10).batch(10)
    model.fit(dataset)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_train_eval_with_steps(self):
    """Counts training batches when validation data comes from a generator."""
    # See b/142880049 for more details.
    inp = keras.Input(shape=(4,), name='inp1')
    out = keras.layers.Dense(2)(inp)
    model = keras.Model(inp, out)
    model.compile(
        'rmsprop', loss='mse', run_eagerly=testing_utils.should_run_eagerly())

    inputs = np.zeros((100, 4), dtype=np.float32)
    targets = np.random.randint(0, 2, size=100, dtype=np.int32)
    training_ds = dataset_ops.Dataset.from_tensor_slices(
        (inputs, targets)).repeat().batch(10)

    # Create eval dataset with generator, so that dataset won't contain the
    # overall size metadata. Without eval_steps, we expect to run through all
    # the data in this dataset every epoch.
    def gen():
      for _ in range(100):
        yield (np.zeros(4, dtype=np.float32),
               np.random.randint(0, 2, size=1, dtype=np.int32))

    eval_ds = dataset_ops.Dataset.from_generator(
        generator=gen,
        output_types=('float64', 'int32'),
        output_shapes=([4], [1])).batch(100)
    batch_counter = BatchCounterCallback()

    model.fit(
        training_ds,
        steps_per_epoch=10,
        epochs=10,
        validation_data=eval_ds,
        callbacks=[batch_counter])

    # Expect 10 training batches per epoch over 10 epochs.
    self.assertEqual(batch_counter.batch_end_count, 100)


class TestMetricsWithDatasets(keras_parameterized.TestCase):
  """Tests metric values reported by `evaluate` on dataset inputs."""

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_metrics_correctness_with_dataset(self):
    """Checks accuracy metrics for random labels and for all-zero labels."""
    layers = [
        keras.layers.Dense(
            8, activation='relu', input_dim=4, kernel_initializer='ones'),
        keras.layers.Dense(1, activation='sigmoid', kernel_initializer='ones')
    ]

    model = testing_utils.get_model_from_layers(layers, (4,))

    model.compile(
        loss='binary_crossentropy',
        metrics=['accuracy', metrics_module.BinaryAccuracy()],
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())

    # Fixed seed so the rounded accuracy asserted below is reproducible.
    np.random.seed(123)
    x = np.random.randint(10, size=(100, 4)).astype(np.float32)
    y = np.random.randint(2, size=(100, 1)).astype(np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(10)
    outs = model.evaluate(dataset, steps=10)
    self.assertEqual(np.around(outs[1], decimals=1), 0.5)
    self.assertEqual(np.around(outs[2], decimals=1), 0.5)

    y = np.zeros((100, 1), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10)
    outs = model.evaluate(dataset, steps=10)
    self.assertEqual(outs[1], 0.)
    self.assertEqual(outs[2], 0.)


if __name__ == '__main__':
  test.main()