# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for generator- and Sequence-based training routines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers as layers_module
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_generator_v1
from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.platform import test
from tensorflow.python.util import nest


def custom_generator(mode=2):
  """Yields random batches: x (mode=1), (x, y) (mode=2), or (x, y, w)."""
  batch_size = 10
  num_samples = 50
  arr_data = np.random.random((num_samples, 2))
  arr_labels = np.random.random((num_samples, 4))
  arr_weights = np.random.random((num_samples,))
  i = 0
  while True:
    batch_index = i * batch_size % num_samples
    i += 1
    start = batch_index
    end = start + batch_size
    x = arr_data[start: end]
    y = arr_labels[start: end]
    w = arr_weights[start: end]
    if mode == 1:
      yield x
    elif mode == 2:
      yield x, y
    else:
      yield x, y, w


def custom_generator_changing_batch_size(mode=2):
  """Like `custom_generator`, but the batch size shrinks by 1 each step."""
  batch_size = 10
  cur_batch_size = 11
  num_samples = 50
  arr_data = np.random.random((num_samples, 2))
  arr_labels = np.random.random((num_samples, 4))
  arr_weights = np.random.random((num_samples,))
  i = 0
  while True:
    if cur_batch_size > 1:
      cur_batch_size -= 1
    batch_index = i * batch_size % num_samples
    i += 1
    start = batch_index
    end = start + cur_batch_size
    x = arr_data[start: end]
    y = arr_labels[start: end]
    w = arr_weights[start: end]
    if mode == 1:
      yield x
    elif mode == 2:
      yield x, y
    else:
      yield x, y, w


custom_generator_threads = data_utils.threadsafe_generator(custom_generator)


class TestGeneratorMethods(keras_parameterized.TestCase):

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_fit_generator_method(self):
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])

    model.fit_generator(custom_generator_threads(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        workers=4,
                        use_multiprocessing=True)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(),
                        validation_steps=10)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_evaluate_generator_method(self):
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.evaluate_generator(custom_generator_threads(),
                             steps=5,
                             max_queue_size=10,
                             workers=2,
                             verbose=1,
                             use_multiprocessing=True)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False,
                             workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_predict_generator_method(self):
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.run_eagerly = testing_utils.should_run_eagerly()

    model.predict_generator(custom_generator_threads(),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            workers=0)
    # Test generator with just inputs (no targets).
    model.predict_generator(custom_generator_threads(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_with_sample_weights(self):
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(mode=3),
                        validation_steps=10)
    model.predict_generator(custom_generator(mode=3),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.evaluate_generator(custom_generator(mode=3),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_invalid_use_case(self):

    def invalid_generator():
      while 1:
        yield (0, 0, 0, 0)

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        run_eagerly=testing_utils.should_run_eagerly())

    with self.assertRaises(ValueError):
      model.fit_generator(invalid_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.fit_generator(custom_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False,
                          validation_data=invalid_generator(),
                          validation_steps=10)
    with self.assertRaises(ValueError):
      model.predict_generator(invalid_generator(),
                              steps=5,
                              max_queue_size=10,
                              use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.evaluate_generator(invalid_generator(),
                               steps=5,
                               max_queue_size=10,
                               use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_input_to_fit_eval_predict(self):
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    def ones_generator():
      while True:
        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(
        rmsprop.RMSprop(0.001),
        'binary_crossentropy',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(
        ones_generator(),
        steps_per_epoch=2,
        validation_data=val_data,
        epochs=2)
    model.evaluate(ones_generator(), steps=2)
    model.predict(ones_generator(), steps=2)

    # Test with a changing batch size.
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator_changing_batch_size(),
                        validation_steps=10)

    model.fit(
        custom_generator_changing_batch_size(),
        steps_per_epoch=5,
        validation_data=custom_generator_changing_batch_size(),
        validation_steps=10,
        epochs=2)
    model.evaluate(custom_generator_changing_batch_size(), steps=5)
    model.predict(custom_generator_changing_batch_size(), steps=5)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_generator_dynamic_shapes(self):

    x = [
        'I think juice is great',
        'unknown is the best language since slicedbread',
        'a a a a a a a',
        'matmul',
        'Yaks are also quite nice',
    ]
    y = [1, 0, 0, 1, 1]

    vocab = {
        word: i + 1
        for i, word in enumerate(
            sorted(set(itertools.chain(*[i.split() for i in x]))))
    }

    def data_gen(batch_size=2):
      np.random.seed(0)
      data = list(zip(x, y)) * 10
      np.random.shuffle(data)

      def pack_and_pad(queue):
        x = [[vocab[j] for j in i[0].split()] for i in queue]
        pad_len = max(len(i) for i in x)
        x = np.array([i + [0] * (pad_len - len(i)) for i in x])
        y = np.array([i[1] for i in queue])
        del queue[:]
        return x, y[:, np.newaxis]

      queue = []
      for i, element in enumerate(data):
        queue.append(element)
        if not (i + 1) % batch_size:
          yield pack_and_pad(queue)

      if queue:
        # Last partial batch.
        yield pack_and_pad(queue)

    model = testing_utils.get_model_from_layers([
        layers_module.Embedding(input_dim=len(vocab) + 1, output_dim=4),
        layers_module.SimpleRNN(units=1),
        layers_module.Activation('sigmoid')
    ], input_shape=(None,))

    model.compile(loss=losses.binary_crossentropy, optimizer='sgd')
    model.fit(data_gen(), epochs=1, steps_per_epoch=5)


class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_training_with_sequences(self):

    class DummySequence(data_utils.Sequence):

      def __getitem__(self, idx):
        return np.zeros([10, 2]), np.ones([10, 4])

      def __len__(self):
        return 10

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3))

    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=True)
    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_sequence_input_to_fit_eval_predict(self):
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    class CustomSequence(data_utils.Sequence):

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def __len__(self):
        return 2

    class CustomSequenceChangingBatchSize(data_utils.Sequence):

      def __getitem__(self, idx):
        batch_size = 10 - idx
        return (np.ones([batch_size, 10], np.float32),
                np.ones([batch_size, 1], np.float32))

      def __len__(self):
        return 2

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(CustomSequence(), validation_data=val_data, epochs=2)
    model.evaluate(CustomSequence())
    model.predict(CustomSequence())

    with self.assertRaisesRegex(ValueError, '`y` argument is not supported'):
      model.fit(CustomSequence(), y=np.ones([10, 1]))

    with self.assertRaisesRegex(ValueError,
                                '`sample_weight` argument is not supported'):
      model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(CustomSequenceChangingBatchSize(),
              validation_data=val_data,
              epochs=2)
    model.evaluate(CustomSequenceChangingBatchSize())
    model.predict(CustomSequenceChangingBatchSize())

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_sequence_on_epoch_end(self):

    class MySequence(data_utils.Sequence):

      def __init__(self):
        self.epochs = 0

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def __len__(self):
        return 2

      def on_epoch_end(self):
        self.epochs += 1

    inputs = input_layer.Input(10)
    outputs = layers_module.Dense(1)(inputs)
    model = training.Model(inputs, outputs)
    model.compile('sgd', 'mse')
    my_seq = MySequence()
    model.fit(my_seq, epochs=2)
    self.assertEqual(my_seq.epochs, 2)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase):
  simple_inputs = (np.ones((10, 10)), np.ones((10, 1)))
  nested_inputs = ((np.ones((10, 10)), np.ones((10, 20))), (np.ones((10, 1)),
                                                            np.ones((10, 3))))

  def _make_dataset(self, inputs, batches):
    return dataset_ops.DatasetV2.from_tensors(inputs).repeat(batches)

  def _make_iterator(self, inputs, batches):
    return dataset_ops.make_one_shot_iterator(
        self._make_dataset(inputs, batches))

  def _make_generator(self, inputs, batches):

    def _gen():
      for _ in range(batches):
        yield inputs

    return _gen()

  def _make_numpy(self, inputs, _):
    return inputs

  @parameterized.named_parameters(
      ('simple_dataset', _make_dataset, simple_inputs),
      ('simple_iterator', _make_iterator, simple_inputs),
      ('simple_generator', _make_generator, simple_inputs),
      ('simple_numpy', _make_numpy, simple_inputs),
      ('nested_dataset', _make_dataset, nested_inputs),
      ('nested_iterator', _make_iterator, nested_inputs),
      ('nested_generator', _make_generator, nested_inputs),
      ('nested_numpy', _make_numpy, nested_inputs))
  def test_convert_to_generator_like(self, input_fn, inputs):
    expected_batches = 5
    data = input_fn(self, inputs, expected_batches)

    # Dataset and Iterator not supported in Legacy Graph mode.
    if (not context.executing_eagerly() and
        isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
      return

    generator, steps = training_generator_v1.convert_to_generator_like(
        data, batch_size=2, steps_per_epoch=expected_batches)
    self.assertEqual(steps, expected_batches)

    for _ in range(expected_batches):
      outputs = next(generator)
      nest.assert_same_structure(outputs, inputs)


if __name__ == '__main__':
  test.main()