# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training routines."""

import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers as layers_module
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_generator_v1
from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.platform import test
from tensorflow.python.util import nest


def custom_generator(mode=2):
  """Endlessly yields fixed-size batches sliced from random NumPy data.

  Args:
    mode: Selects what each yielded item contains. `1` yields `x` only, `2`
      yields `(x, y)`, and any other value yields `(x, y, w)`.

  Yields:
    Slices of 10 consecutive rows taken from 50 random samples, where `x` has
    2 columns, `y` has 4 columns and `w` is one-dimensional.
  """
  batch_size = 10
  num_samples = 50
  arr_data = np.random.random((num_samples, 2))
  arr_labels = np.random.random((num_samples, 4))
  arr_weights = np.random.random((num_samples,))
  i = 0
  while True:
    # The modulo wraps the start offset so the generator never runs dry.
    batch_index = i * batch_size % num_samples
    i += 1
    start = batch_index
    end = start + batch_size
    x = arr_data[start: end]
    y = arr_labels[start: end]
    w = arr_weights[start: end]
    if mode == 1:
      yield x
    elif mode == 2:
      yield x, y
    else:
      yield x, y, w


def custom_generator_changing_batch_size(mode=2):
  """Like `custom_generator`, but the batch size shrinks on every step.

  The first batch has 10 rows, the next 9, and so on down to 1 row; after
  that every batch has a single row.

  Args:
    mode: Selects what each yielded item contains. `1` yields `x` only, `2`
      yields `(x, y)`, and any other value yields `(x, y, w)`.

  Yields:
    Slices of the random sample arrays whose length decreases as described
    above.
  """
  batch_size = 10
  cur_batch_size = 11
  num_samples = 50
  arr_data = np.random.random((num_samples, 2))
  arr_labels = np.random.random((num_samples, 4))
  arr_weights = np.random.random((num_samples,))
  i = 0
  while True:
    # Shrink the batch by one row per step, but never below a single row.
    if cur_batch_size > 1:
      cur_batch_size -= 1
    # Start offsets still advance by the nominal batch size and wrap around.
    batch_index = i * batch_size % num_samples
    i += 1
    start = batch_index
    end = start + cur_batch_size
    x = arr_data[start: end]
    y = arr_labels[start: end]
    w = arr_weights[start: end]
    if mode == 1:
      yield x
    elif mode == 2:
      yield x, y
    else:
      yield x, y, w


# `custom_generator` wrapped with `data_utils.threadsafe_generator`; used by
# the tests below whenever more than one worker is requested.
custom_generator_threads = data_utils.threadsafe_generator(custom_generator)


class TestGeneratorMethods(keras_parameterized.TestCase):
  """Runs fit/evaluate/predict on Python generators in all Keras modes."""

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_fit_generator_method(self):
    """Calls `fit_generator` with several worker/validation settings."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])

    model.fit_generator(custom_generator_threads(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        workers=4,
                        use_multiprocessing=True)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(),
                        validation_steps=10)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_evaluate_generator_method(self):
    """Calls `evaluate_generator` with several worker settings."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.evaluate_generator(custom_generator_threads(),
                             steps=5,
                             max_queue_size=10,
                             workers=2,
                             verbose=1,
                             use_multiprocessing=True)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False,
                             workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_predict_generator_method(self):
    """Calls `predict_generator` on `(x, y)` and on input-only generators."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    # The model is never compiled here, so the eager flag is set directly.
    model.run_eagerly = testing_utils.should_run_eagerly()

    model.predict_generator(custom_generator_threads(),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            workers=0)
    # Test generator with just inputs (no targets)
    model.predict_generator(custom_generator_threads(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_with_sample_weights(self):
    """Runs fit/predict/evaluate on generators yielding `(x, y, w)`."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(mode=3),
                        validation_steps=10)
    model.predict_generator(custom_generator(mode=3),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.evaluate_generator(custom_generator(mode=3),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_invalid_use_case(self):
    """Expects `ValueError` when a generator yields 4-element tuples."""

    def invalid_generator():
      while 1:
        yield (0, 0, 0, 0)

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        run_eagerly=testing_utils.should_run_eagerly())

    with self.assertRaises(ValueError):
      model.fit_generator(invalid_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.fit_generator(custom_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False,
                          validation_data=invalid_generator(),
                          validation_steps=10)
    with self.assertRaises(ValueError):
      model.predict_generator(invalid_generator(),
                              steps=5,
                              max_queue_size=10,
                              use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.evaluate_generator(invalid_generator(),
                               steps=5,
                               max_queue_size=10,
                               use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_input_to_fit_eval_predict(self):
    """Passes generators straight to `fit`, `evaluate` and `predict`."""
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    def ones_generator():
      while True:
        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(
        rmsprop.RMSprop(0.001),
        'binary_crossentropy',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(
        ones_generator(),
        steps_per_epoch=2,
        validation_data=val_data,
        epochs=2)
    model.evaluate(ones_generator(), steps=2)
    model.predict(ones_generator(), steps=2)

    # Test with a changing batch size
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator_changing_batch_size(),
                        validation_steps=10)

    model.fit(
        custom_generator_changing_batch_size(),
        steps_per_epoch=5,
        validation_data=custom_generator_changing_batch_size(),
        validation_steps=10,
        epochs=2)
    model.evaluate(custom_generator_changing_batch_size(), steps=5)
    model.predict(custom_generator_changing_batch_size(), steps=5)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_generator_dynamic_shapes(self):
    """Fits an RNN on batches whose padded sequence length varies."""

    # One sentence per label. Every entry needs its own trailing comma:
    # adjacent string literals are silently concatenated by Python, which
    # would leave `x` shorter than `y`.
    x = [
        'I think juice is great',
        'unknown is the best language since slicedbread',
        'a a a a a a a',
        'matmul',
        'Yaks are also quite nice',
    ]
    y = [1, 0, 0, 1, 1]

    # Word ids start at 1 because 0 is used as the padding value below.
    words = itertools.chain.from_iterable(
        sentence.split() for sentence in x)
    vocab = {word: i + 1 for i, word in enumerate(sorted(set(words)))}

    def data_gen(batch_size=2):
      """Yields `(token_ids, labels)` batches padded to the longest row."""
      np.random.seed(0)
      data = list(zip(x, y)) * 10
      np.random.shuffle(data)

      def pack_and_pad(queue):
        """Converts queued `(sentence, label)` pairs to arrays; empties it."""
        token_ids = [[vocab[word] for word in sentence.split()]
                     for sentence, _ in queue]
        pad_len = max(len(ids) for ids in token_ids)
        padded = np.array(
            [ids + [0] * (pad_len - len(ids)) for ids in token_ids])
        labels = np.array([label for _, label in queue])
        # Clear in place so the caller's list is ready for the next batch.
        del queue[:]
        return padded, labels[:, np.newaxis]

      queue = []
      for i, element in enumerate(data):
        queue.append(element)
        if not (i + 1) % batch_size:
          yield pack_and_pad(queue)

      if queue:
        # Last partial batch
        yield pack_and_pad(queue)

    model = testing_utils.get_model_from_layers([
        layers_module.Embedding(input_dim=len(vocab) + 1, output_dim=4),
        layers_module.SimpleRNN(units=1),
        layers_module.Activation('sigmoid')
    ],
                                                input_shape=(None,))

    model.compile(loss=losses.binary_crossentropy, optimizer='sgd')
    model.fit(data_gen(), epochs=1, steps_per_epoch=5)


class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):
  """Runs fit/evaluate/predict on `data_utils.Sequence` inputs."""

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_training_with_sequences(self):
    """Calls `fit_generator` on a `Sequence` with `workers=0`."""

    class DummySequence(data_utils.Sequence):

      def __getitem__(self, idx):
        return np.zeros([10, 2]), np.ones([10, 4])

      def __len__(self):
        return 10

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3))

    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=True)
    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_sequence_input_to_fit_eval_predict(self):
    """Passes `Sequence` objects straight to `fit`/`evaluate`/`predict`."""
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    class CustomSequence(data_utils.Sequence):

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def __len__(self):
        return 2

    class CustomSequenceChangingBatchSize(data_utils.Sequence):

      def __getitem__(self, idx):
        # The batch gets one row smaller for each successive index.
        batch_size = 10 - idx
        return (np.ones([batch_size, 10], np.float32),
                np.ones([batch_size, 1], np.float32))

      def __len__(self):
        return 2

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(CustomSequence(), validation_data=val_data, epochs=2)
    model.evaluate(CustomSequence())
    model.predict(CustomSequence())

    with self.assertRaisesRegex(ValueError, '`y` argument is not supported'):
      model.fit(CustomSequence(), y=np.ones([10, 1]))

    with self.assertRaisesRegex(ValueError,
                                '`sample_weight` argument is not supported'):
      model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(CustomSequenceChangingBatchSize(),
              validation_data=val_data, epochs=2)
    model.evaluate(CustomSequenceChangingBatchSize())
    model.predict(CustomSequenceChangingBatchSize())

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_sequence_on_epoch_end(self):
    """Checks `on_epoch_end` has run once per epoch after `fit`."""

    class MySequence(data_utils.Sequence):

      def __init__(self):
        self.epochs = 0

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def __len__(self):
        return 2

      def on_epoch_end(self):
        self.epochs += 1

    inputs = input_layer.Input(10)
    outputs = layers_module.Dense(1)(inputs)
    model = training.Model(inputs, outputs)
    model.compile('sgd', 'mse')
    my_seq = MySequence()
    model.fit(my_seq, epochs=2)
    self.assertEqual(my_seq.epochs, 2)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase):
  """Tests `training_generator_v1.convert_to_generator_like`."""

  simple_inputs = (np.ones((10, 10)), np.ones((10, 1)))
  nested_inputs = ((np.ones((10, 10)), np.ones((10, 20))), (np.ones((10, 1)),
                                                            np.ones((10, 3))))

  # The `_make_*` helpers are handed to `named_parameters` below as plain
  # functions (the class body is still executing at that point), which is why
  # the test calls them with `self` passed explicitly.

  def _make_dataset(self, inputs, batches):
    """Returns a dataset that repeats `inputs` `batches` times."""
    return dataset_ops.DatasetV2.from_tensors(inputs).repeat(batches)

  def _make_iterator(self, inputs, batches):
    """Returns a one-shot iterator over the `_make_dataset` dataset."""
    return dataset_ops.make_one_shot_iterator(
        self._make_dataset(inputs, batches))

  def _make_generator(self, inputs, batches):
    """Returns a Python generator yielding `inputs` `batches` times."""

    def _gen():
      for _ in range(batches):
        yield inputs

    return _gen()

  def _make_numpy(self, inputs, _):
    """Returns `inputs` unchanged; the batch count is ignored."""
    return inputs

  @parameterized.named_parameters(
      ('simple_dataset', _make_dataset, simple_inputs),
      ('simple_iterator', _make_iterator, simple_inputs),
      ('simple_generator', _make_generator, simple_inputs),
      ('simple_numpy', _make_numpy, simple_inputs),
      ('nested_dataset', _make_dataset, nested_inputs),
      ('nested_iterator', _make_iterator, nested_inputs),
      ('nested_generator', _make_generator, nested_inputs),
      ('nested_numpy', _make_numpy, nested_inputs))
  def test_convert_to_generator_like(self, input_fn, inputs):
    """Checks the step count and the nesting of each produced batch."""
    expected_batches = 5
    data = input_fn(self, inputs, expected_batches)

    # Dataset and Iterator not supported in Legacy Graph mode.
    if (not context.executing_eagerly() and
        isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
      return

    generator, steps = training_generator_v1.convert_to_generator_like(
        data, batch_size=2, steps_per_epoch=expected_batches)
    self.assertEqual(steps, expected_batches)

    # Every batch must have the same nesting as the inputs it came from.
    for _ in range(expected_batches):
      outputs = next(generator)
      nest.assert_same_structure(outputs, inputs)


if __name__ == '__main__':
  test.main()