# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to Python generators of array data.
"""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Arguments:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: Verbosity mode, 0, 1 or 2.
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs validation
        every 2 epochs. If a Container, specifies the epochs on which to run
        validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the
        end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1.
        If 0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to child processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Will only be used if
        `data` is in NumPy/Tensor format.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
  """
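  # Overview of the body below: (1) normalize `data` into something that
  # supports `next()`, (2) configure callbacks, the progress bar, and an
  # aggregator for outputs or metrics, then (3) loop over epochs, drawing
  # batches until `steps_per_epoch` is reached or the data runs out,
  # optionally running a TEST-mode pass over `validation_data` at the end of
  # each epoch.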
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils.infer_steps_for_dataset(
          data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # Create the queue for the generator.
  enqueuer = None
  if not is_dataset:
    generator, enqueuer = _make_enqueued_generator(
        generator,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        shuffle=shuffle)

  num_samples_or_steps, use_steps = _get_num_samples_or_steps(
      data, steps_per_epoch)

  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      samples=num_samples_or_steps,
      verbose=0,  # Handle ProgBar as part of Callbacks once hooks are ready.
      mode=mode)
  # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
  progbar = training_utils.get_progbar(model, count_mode)
  progbar.params = callbacks.params
  progbar.params['verbose'] = verbose

  if mode == ModeKeys.PREDICT:
    aggregator = training_utils.OutputsAggregator(True, steps_per_epoch)
  else:
    aggregator = training_utils.MetricsAggregator(True, steps_per_epoch)

  should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
  if should_set_learning_phase:
    old_learning_phase = backend.learning_phase()
    backend.set_eager_learning_phase(1 if mode == ModeKeys.TRAIN else 0)

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)
  progbar.on_train_begin()
  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Setup work for each epoch.
    model.reset_metrics()
    epoch_logs = {}
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)
    progbar.on_epoch_begin(epoch, epoch_logs)

    if steps_per_epoch is None:
      # Loop over dataset until `OutOfRangeError` is raised.
      target_steps = np.inf
    else:
      # Loop over dataset for the specified number of steps.
      target_steps = steps_per_epoch

    step = 0
    while step < target_steps:
      batch_data = _get_next_batch(generator, mode)
      if batch_data is None:
        if is_dataset:
          # The dataset passed by the user ran out of batches.
          # Now we know the cardinality of the dataset.
          # If steps_per_epoch was specified, then running out of data is
          # unexpected, so we stop training and inform the user.
          if steps_per_epoch:
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset ran out of data; interrupting training. '
                'Make sure that your dataset can generate at least '
                '`%s * epochs` batches (in this case, %d batches). '
                'You may need to use the repeat() function when '
                'building your dataset.'
                % (steps_name, steps_per_epoch * epochs))
          elif step > 0:
            steps_per_epoch = step
            aggregator.num_samples_or_steps = steps_per_epoch
            if mode == ModeKeys.TRAIN:
              progbar.params['steps'] = steps_per_epoch
              progbar.progbar.target = steps_per_epoch
        else:
          # We ran out of batches while the user passed an iterator (legacy).
          callbacks.model.stop_training = True
          logging.warning(
              'Your dataset iterator ran out of data; '
              'interrupting training. Make sure that your iterator '
              'can generate at least `%s * epochs` '
              'batches (in this case, %d batches). You may need to '
              'use the repeat() function when building your '
              'dataset.' % (steps_name, steps_per_epoch * epochs))
        break

      # `batch_size` used for validation data if validation
      # data is NumPy/EagerTensors.
      batch_size = int(nest.flatten(batch_data)[0].shape[0])

      # Callbacks batch begin.
      batch_logs = {'batch': step, 'size': batch_size}
      callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
      progbar.on_batch_begin(step, batch_logs)

      is_deferred = not model._is_compiled
      batch_outs = batch_function(*batch_data)
      if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]

      if step == 0:
        aggregator.create(batch_outs)

        if is_deferred:
          # Set callbacks params. We do this here when model is compiled only
          # in the first iteration of this loop (deferred build scenario).
          cbks.set_callback_parameters(
              callbacks,
              model,
              do_validation=do_validation,
              batch_size=batch_size,
              epochs=epochs,
              steps_per_epoch=steps_per_epoch,
              samples=num_samples_or_steps,
              verbose=verbose,
              mode=mode)

          progbar.params = callbacks.params
          progbar.params['verbose'] = verbose

      # Aggregate results.
      aggregator.aggregate(batch_outs)

      # Callbacks batch end.
      batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
      callbacks._call_batch_hook(mode, 'end', step, batch_logs)
      progbar.on_batch_end(step, batch_logs)
      step += 1

    if callbacks.model.stop_training:
      break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop every epoch during training.
    if (do_validation and
        training_utils.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):
      val_results = model_iteration(
          model,
          validation_data,
          steps_per_epoch=validation_steps,
          batch_size=batch_size,
          class_weight=class_weight,
          workers=workers,
          use_multiprocessing=use_multiprocessing,
          max_queue_size=max_queue_size,
          callbacks=callbacks,
          verbose=0,
          mode=ModeKeys.TEST,
          steps_name='validation_steps')

      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)
    progbar.on_epoch_end(epoch, epoch_logs)

    # Recreate dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      generator = dataset_ops.make_one_shot_iterator(original_dataset)

  callbacks._call_end_hook(mode)

  if enqueuer is not None:
    enqueuer.stop()

  if should_set_learning_phase:
    backend.set_eager_learning_phase(old_learning_phase)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results


# Maintain compatibility with the existing names.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_generator = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
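
# For illustration only (not part of the module's API): a minimal sketch of
# how these entry points are invoked, assuming a compiled `model` and
# `keras.utils.Sequence` instances `train_seq`, `val_seq`, and `test_seq`
# (all hypothetical names):
#
#   history = fit_generator(model, train_seq, epochs=2, workers=2)
#   eval_results = evaluate_generator(model, val_seq, steps_per_epoch=10)
#   predictions = predict_generator(model, test_seq, steps_per_epoch=10)
#
# All three are thin `functools.partial` wrappers that only fix `mode` (and
# `shuffle` for TEST/PREDICT), so every keyword accepted by `model_iteration`
# passes through unchanged.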


def _get_next_batch(generator, mode):
  """Retrieves the next batch of input data."""
  try:
    generator_output = next(generator)
  except (StopIteration, errors.OutOfRangeError):
    return None
  if not isinstance(generator_output, tuple):
    if mode == ModeKeys.PREDICT:
      # Always wrap in a tuple.
      return (generator_output,)
    else:
      raise ValueError('Output of generator should be '
                       'a tuple `(x, y, sample_weight)` '
                       'or `(x, y)`. Found: ' + str(generator_output))

  if len(generator_output) < 1 or len(generator_output) > 3:
    raise ValueError('Output of generator should be '
                     'a tuple `(x, y, sample_weight)` '
                     'or `(x, y)` or `(x,)`. Found: ' + str(generator_output))
  return generator_output


def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Raises errors if arguments are invalid.

  Arguments:
      is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
        instance.
      is_dataset: Boolean, whether data is a dataset instance.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to child processes.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1.
        If 0, will execute the generator on the main thread.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      kwargs: Additional arguments for backwards compatibility.

  Raises:
      ValueError: If `steps_per_epoch` or `validation_steps` are not passed
        for data types that require them, or if unrecognized keyword
        arguments are passed.
  """
  if not is_sequence and use_multiprocessing and workers > 1:
    logging.warning(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.'))

  if steps_per_epoch is None and not is_dataset:
    arg_name = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(arg_name))

  val_gen = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.EagerIterator))
  if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
      not validation_steps):
    raise ValueError('Please specify the `validation_steps` argument.')

  if any(k != 'steps' for k in kwargs):
    raise ValueError('Invalid arguments passed: {}'.format(
        [k for k in kwargs if k != 'steps']))


def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Arguments:
      data: Either a generator or `keras.utils.data_utils.Sequence` object or
        `Dataset` or `EagerIterator` or a {1,2,3}-tuple of NumPy arrays or
        EagerTensors. If a tuple, the elements represent `(x, y,
        sample_weights)` and may be `None` or `[None]`.
      batch_size: Used when creating a generator out of tuples of NumPy arrays
        or EagerTensors.
      steps_per_epoch: Steps of the generator to run each epoch. If `None` the
        number of steps will be read from the data (for
        `keras.utils.data_utils.Sequence` types).
      epochs: Total number of epochs to run.
      shuffle: Whether the data should be shuffled.

  Returns:
      A tuple of a generator-like object (generator,
      `keras.utils.data_utils.Sequence`, or EagerIterator) and
      `steps_per_epoch` (inferred from the data when possible).

  Raises:
      ValueError: If `batch_size` is not provided for NumPy or EagerTensor
        inputs.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`,
    # `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))
    if len(data) == 1:
      data = data[0]

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.EagerIterator):
    if isinstance(data, data_utils.Sequence):
      if steps_per_epoch is None:
        steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor Input.
  num_samples = int(nest.flatten(data)[0].shape[0])
  if batch_size is None:
    raise ValueError('You must specify `batch_size`')
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen(data):
    """Makes a generator out of a structure of NumPy/EagerTensors."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      batches = generic_utils.make_batches(num_samples, batch_size)
      for (batch_start, batch_end) in batches:
        batch_ids = index_array[batch_start:batch_end]
        flat_batch_data = training_utils.slice_arrays(
            nest.flatten(data), batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(data), steps_per_epoch
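
# For illustration only (hypothetical shapes): converting a tuple of NumPy
# arrays into a generator-like object.
#
#   x = np.random.random((10, 3))
#   y = np.random.random((10, 1))
#   gen, steps = convert_to_generator_like((x, y), batch_size=4)
#   batch_x, batch_y = next(gen)  # batch_x: (4, 3), batch_y: (4, 1)
#   # steps == 3, since ceil(10 / 4) == 3; the final batch holds 2 samples.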


def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Create a buffered queue of next elements of the generator."""
  is_sequence = isinstance(generator, data_utils.Sequence)
  enqueuer = None
  if workers > 0:
    if is_sequence:
      enqueuer = data_utils.OrderedEnqueuer(
          generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
    else:
      enqueuer = data_utils.GeneratorEnqueuer(
          generator, use_multiprocessing=use_multiprocessing)
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    output_generator = enqueuer.get()
  else:
    if is_sequence:
      output_generator = data_utils.iter_sequence_infinite(generator)
    else:
      output_generator = generator
  return output_generator, enqueuer


def _make_execution_function(model, mode, class_weight=None):
  """Makes function to run one step of model execution."""
  if mode == ModeKeys.TRAIN:
    f = functools.partial(model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    f = model.test_on_batch
  else:
    # Match signature of other modes to allow
    # 1, 2, or 3-tuples from generator.
    def predict_on_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
      return model.predict_on_batch(x)

    f = predict_on_batch

  # Maintain stateful metrics across batch-level calls.
  if mode != ModeKeys.PREDICT:
    f = functools.partial(f, reset_metrics=False)

  return f
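
# For illustration only (hypothetical `model` and batch values): the returned
# function is called with the unpacked generator output, so a 2-tuple batch
# `(batch_x, batch_y)` in TRAIN mode behaves as follows:
#
#   f = _make_execution_function(model, ModeKeys.TRAIN)
#   batch_outs = f(batch_x, batch_y)
#   # equivalent to:
#   # model.train_on_batch(batch_x, batch_y, class_weight=None,
#   #                      reset_metrics=False)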


def _get_num_samples_or_steps(data, steps_per_epoch):
  """Returns number of samples or steps, and whether to use steps count mode."""
  flat_inputs = nest.flatten(data)
  if hasattr(flat_inputs[0], 'shape'):
    return int(flat_inputs[0].shape[0]), False
  return steps_per_epoch, True
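
# For illustration only (hypothetical inputs): the helper keys off whether the
# first flattened input carries a `shape` attribute.
#
#   _get_num_samples_or_steps((np.zeros((32, 4)), np.zeros((32,))), None)
#   # -> (32, False): progress is counted in samples.
#   _get_num_samples_or_steps(some_generator, 100)
#   # -> (100, True): progress is counted in steps.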