• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Part of the Keras training engine related to Python generators of array data.
16"""
17# pylint: disable=protected-access
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import math
24
25import numpy as np
26
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.data.ops import iterator_ops
29from tensorflow.python.eager import context
30from tensorflow.python.framework import errors
31from tensorflow.python.keras import backend
32from tensorflow.python.keras import callbacks as cbks
33from tensorflow.python.keras.engine import training_utils
34from tensorflow.python.keras.utils import data_utils
35from tensorflow.python.keras.utils import generic_utils
36from tensorflow.python.keras.utils.mode_keys import ModeKeys
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.util import nest
39
40
def model_iteration(model,
                    data,
                    steps_per_epoch=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    validation_freq=1,
                    class_weight=None,
                    max_queue_size=10,
                    workers=1,
                    use_multiprocessing=False,
                    shuffle=False,
                    initial_epoch=0,
                    mode=ModeKeys.TRAIN,
                    batch_size=None,
                    steps_name='steps',
                    **kwargs):
  """Loop function for generator-like data with modes TRAIN/TEST/PREDICT.

  `data` is first converted into something that supports `next()` (see
  `convert_to_generator_like`), optionally buffered through an enqueuer, and
  then consumed batch by batch for every epoch.

  Arguments:
      model: Keras Model instance.
      data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
        `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      epochs: Number of times to iterate over the data.
      verbose: Verbosity mode, 0, 1 or 2.
      callbacks: List of callbacks to be called during training.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      class_weight: Dictionary mapping class indices to a weight for the class.
      max_queue_size: Integer. Maximum size for the generator queue. If
        unspecified, `max_queue_size` will default to 10.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1. If
        0, will execute the generator on the main thread.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-picklable arguments to the generator as they can't be passed
        easily to children processes.
      shuffle: Boolean. Whether to shuffle the order of the batches at the
        beginning of each epoch. Only used with instances of `Sequence`
        (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
        `None`.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      batch_size: Integer batch size or None if unknown. Used to batch `data`
        when it is in NumPy/Tensor format. It is also forwarded to the
        validation run (relevant when `validation_data` is in NumPy/Tensor
        format); if `None`, the size of the first training batch is forwarded
        instead.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility. `steps` is
        accepted as an alias for `steps_per_epoch`.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
  """
  if 'steps' in kwargs:
    steps_per_epoch = kwargs['steps']

  # Determine the number of steps per epoch and whether we should reset the
  # dataset at the end of each epoch.
  reset_dataset_after_each_epoch = False
  original_dataset = None
  is_dataset = isinstance(data, (dataset_ops.DatasetV2, dataset_ops.DatasetV1))
  if is_dataset:
    original_dataset = data
    if steps_per_epoch is None:
      reset_dataset_after_each_epoch = True
      steps_per_epoch = training_utils.infer_steps_for_dataset(
          data, steps_per_epoch, epochs=epochs, steps_name=steps_name)

  # Convert to a format that supports `next(generator)`.
  generator, steps_per_epoch = convert_to_generator_like(
      data,
      steps_per_epoch=steps_per_epoch,
      batch_size=batch_size,
      epochs=epochs - initial_epoch,
      shuffle=shuffle)

  do_validation = validation_data is not None
  is_sequence = isinstance(generator, data_utils.Sequence)
  _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                      steps_per_epoch, validation_data, validation_steps, mode,
                      kwargs)

  batch_function = _make_execution_function(
      model, mode, class_weight=class_weight)

  # Create the queue for the generator.
  enqueuer = None
  if not is_dataset:
    generator, enqueuer = _make_enqueued_generator(
        generator,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        max_queue_size=max_queue_size,
        shuffle=shuffle)

  should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
  learning_phase_overridden = False
  old_learning_phase = None

  # Everything after the enqueuer has been started runs under `try` so that
  # its workers are stopped and the eager learning phase is restored even if
  # the loop below is interrupted by an exception.
  try:
    num_samples_or_steps, use_steps = _get_num_samples_or_steps(
        data, steps_per_epoch)

    count_mode = 'steps' if use_steps else 'samples'
    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        samples=num_samples_or_steps,
        verbose=0,  # Handle ProgBar as part of Callbacks once hooks are ready.
        mode=mode)
    # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
    progbar = training_utils.get_progbar(model, count_mode)
    progbar.params = callbacks.params
    progbar.params['verbose'] = verbose

    if mode == ModeKeys.PREDICT:
      aggregator = training_utils.OutputsAggregator(True, steps_per_epoch)
    else:
      aggregator = training_utils.MetricsAggregator(True, steps_per_epoch)

    if should_set_learning_phase:
      old_learning_phase = backend.learning_phase()
      backend.set_eager_learning_phase(1 if mode == ModeKeys.TRAIN else 0)
      learning_phase_overridden = True

    # Batch size forwarded to the validation run (used when the validation
    # data is NumPy/EagerTensors). Prefer the caller's value; otherwise take
    # the size of the first training batch. The last batch of an epoch may be
    # partial, so it must not be used for this.
    validation_batch_size = batch_size

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)
    progbar.on_train_begin()
    for epoch in range(initial_epoch, epochs):
      if callbacks.model.stop_training:
        break

      # Setup work for each epoch.
      model.reset_metrics()
      epoch_logs = {}
      if mode == ModeKeys.TRAIN:
        callbacks.on_epoch_begin(epoch, epoch_logs)
      progbar.on_epoch_begin(epoch, epoch_logs)

      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_data = _get_next_batch(generator, mode)
        if batch_data is None:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.num_samples_or_steps = steps_per_epoch
              if mode == ModeKeys.TRAIN:
                progbar.params['steps'] = steps_per_epoch
                progbar.progbar.target = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        # Size of the batch actually produced by the generator.
        current_batch_size = int(nest.flatten(batch_data)[0].shape[0])
        if validation_batch_size is None:
          validation_batch_size = current_batch_size

        # Callbacks batch begin.
        batch_logs = {'batch': step, 'size': current_batch_size}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
        progbar.on_batch_begin(step, batch_logs)

        is_deferred = not model._is_compiled
        batch_outs = batch_function(*batch_data)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if step == 0:
          aggregator.create(batch_outs)

          if is_deferred:
            # Set callbacks params. We do this here when model is compiled only
            # in the first iteration of this loop (deferred build scenario).
            cbks.set_callback_parameters(
                callbacks,
                model,
                do_validation=do_validation,
                batch_size=current_batch_size,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                samples=num_samples_or_steps,
                verbose=verbose,
                mode=mode)

            progbar.params = callbacks.params
            progbar.params['verbose'] = verbose

        # Aggregate results.
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        progbar.on_batch_end(step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break

      aggregator.finalize()
      results = aggregator.results
      epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
      if len(results) == 1:
        results = results[0]

      # Run the test loop every epoch during training.
      if (do_validation and
          training_utils.should_run_validation(validation_freq, epoch) and
          not callbacks.model.stop_training):
        val_results = model_iteration(
            model,
            validation_data,
            steps_per_epoch=validation_steps,
            batch_size=validation_batch_size,
            class_weight=class_weight,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            callbacks=callbacks,
            verbose=0,
            mode=ModeKeys.TEST,
            steps_name='validation_steps')

        if not isinstance(val_results, list):
          val_results = [val_results]
        epoch_logs = cbks.make_logs(
            model, epoch_logs, val_results, mode, prefix='val_')

      if mode == ModeKeys.TRAIN:
        # Epochs only apply to `fit`.
        callbacks.on_epoch_end(epoch, epoch_logs)
      progbar.on_epoch_end(epoch, epoch_logs)

      # Recreate dataset iterator for the next epoch.
      if reset_dataset_after_each_epoch and epoch < epochs - 1:
        generator = dataset_ops.make_one_shot_iterator(original_dataset)

    callbacks._call_end_hook(mode)
  finally:
    if enqueuer is not None:
      enqueuer.stop()
    if learning_phase_overridden:
      backend.set_eager_learning_phase(old_learning_phase)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results
342
343
# Maintain compatibility with the existing names. Each alias pins `mode`;
# the evaluation and prediction aliases additionally pin `shuffle=False`.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.TEST)
predict_generator = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.PREDICT)
350
351
def _get_next_batch(generator, mode):
  """Pulls one batch from `generator` and normalizes it to a tuple.

  Arguments:
    generator: Iterator yielding batches of input data.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    The batch as a tuple of 1 to 3 elements, or `None` once the iterator is
    exhausted (`StopIteration` or `OutOfRangeError`).

  Raises:
    ValueError: If the batch is not a tuple (outside PREDICT mode) or is a
      tuple whose length is not between 1 and 3.
  """
  try:
    batch = next(generator)
  except (StopIteration, errors.OutOfRangeError):
    return None

  if isinstance(batch, tuple):
    if not 1 <= len(batch) <= 3:
      raise ValueError('Output of generator should be '
                       'a tuple `(x, y, sample_weight)` '
                       'or `(x, y)` or (x,). Found: ' + str(batch))
    return batch

  if mode != ModeKeys.PREDICT:
    raise ValueError('Output of generator should be '
                     'a tuple `(x, y, sample_weight)` '
                     'or `(x, y)`. Found: ' + str(batch))
  # PREDICT accepts a bare input; wrap it so callers can always unpack.
  return (batch,)
372
373
def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
                        steps_per_epoch, validation_data, validation_steps,
                        mode, kwargs):
  """Checks the arguments of `model_iteration`, warning or raising on misuse.

  Arguments:
    is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
      instance.
    is_dataset: Boolean, whether data is a dataset instance.
    use_multiprocessing: Boolean. If `True`, use process-based threading.
    workers: Integer. Maximum number of processes to spin up when using
      process-based threading; 0 executes the generator on the main thread.
    steps_per_epoch: Total number of steps (batches of samples) before
      declaring one epoch finished, or `None`.
    validation_data: Either a tuple of NumPy/Tensor inputs or a generator or
      `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset,
      or `None`.
    validation_steps: Total number of steps (batches of samples) before
      declaring validation finished.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    kwargs: Additional keyword arguments; only `steps` is accepted.

  Raises:
    ValueError: If `steps_per_epoch` or `validation_steps` are not passed
      for data types that require them, or if unrecognized keyword
      arguments are passed.
  """
  # A plain generator driven by several processes may yield duplicate data.
  if not is_sequence and use_multiprocessing and workers > 1:
    logging.warning(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence`'
                    ' class.'))

  # Only datasets can have their step count inferred here.
  if not is_dataset and steps_per_epoch is None:
    steps_arg = 'steps_per_epoch' if mode == ModeKeys.TRAIN else 'steps'
    raise ValueError('Please specify the number of steps via the '
                     '`{}` argument.'.format(steps_arg))

  validation_is_iterable = (
      data_utils.is_generator_or_sequence(validation_data) or
      isinstance(validation_data, iterator_ops.EagerIterator))
  # A `Sequence` knows its own length; other generators/iterators do not.
  needs_validation_steps = (
      validation_is_iterable and
      not isinstance(validation_data, data_utils.Sequence))
  if needs_validation_steps and not validation_steps:
    raise ValueError('Please specify the `validation_steps` argument.')

  unexpected = [k for k in kwargs if k != 'steps']
  if unexpected:
    raise ValueError('Invalid arguments passed: {}'.format(unexpected))
429
430
def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Generators, `keras.utils.data_utils.Sequence` objects and `EagerIterator`s
  are returned unchanged, and a `Dataset` is wrapped in a one-shot iterator.
  Any other input is treated as a structure of NumPy arrays / EagerTensors
  and turned into a Python generator that yields `epochs` passes of batches.

  Arguments:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset` or `EagerIterator` or a {1,2,3}-tuple of NumPy arrays or
      EagerTensors. If a tuple, the elements represent `(x, y, sample_weights)`
      and may be `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays or
      EagerTensors.
    steps_per_epoch: Steps of the generator to run each epoch. If `None` the
      number of steps will be read from the data (for
      `keras.utils.data_utils.Sequence` types). For NumPy/EagerTensor inputs
      it is always recomputed from the number of samples and `batch_size`.
    epochs: Number of passes over the data the returned generator yields
      (NumPy/EagerTensor inputs only).
    shuffle: Whether to shuffle the sample order at the start of each pass
      (NumPy/EagerTensor inputs only).

  Returns:
    A `(generator_like, steps_per_epoch)` tuple, where `generator_like` is a
    generator, `keras.utils.data_utils.Sequence` or iterator, and
    `steps_per_epoch` is the (possibly inferred) number of steps per epoch.

  Raises:
    ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs, or if no input arrays remain once `None` entries are scrubbed.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`, `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))
    if len(data) == 1:
      data = data[0]

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.EagerIterator):
    if isinstance(data, data_utils.Sequence) and steps_per_epoch is None:
      steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor Input.
  if batch_size is None:
    raise ValueError('You must specify `batch_size`')
  # Flatten once; the structure does not change between batches.
  flat_data = nest.flatten(data)
  if not flat_data:
    # E.g. every element of the input tuple was `None` and got scrubbed.
    raise ValueError('No input data found in `data`.')
  num_samples = int(flat_data[0].shape[0])
  steps_per_epoch = int(math.ceil(num_samples / batch_size))

  def _gen():
    """Yields `epochs` passes of mini-batches sliced from `data`."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      for batch_start, batch_end in generic_utils.make_batches(
          num_samples, batch_size):
        batch_ids = index_array[batch_start:batch_end]
        # `index_array` is only in order when not shuffling; `contiguous`
        # presumably lets `slice_arrays` use a plain slice in that case.
        flat_batch_data = training_utils.slice_arrays(
            flat_data, batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(), steps_per_epoch
494
495
def _make_enqueued_generator(generator,
                             workers=1,
                             use_multiprocessing=False,
                             max_queue_size=10,
                             shuffle=False):
  """Wraps `generator` so batches can be pulled with `next()`.

  Arguments:
    generator: A generator or `keras.utils.data_utils.Sequence`.
    workers: Number of workers feeding the queue; 0 disables the queue and
      reads from `generator` on the calling thread.
    use_multiprocessing: Whether the enqueuer uses processes instead of
      threads.
    max_queue_size: Maximum number of buffered batches.
    shuffle: Whether the `Sequence` enqueuer shuffles batch order. Ignored
      for plain generators.

  Returns:
    An `(output_generator, enqueuer)` tuple. `enqueuer` is `None` when no
    queue was created; otherwise it has already been started and must be
    stopped by the caller.
  """
  is_sequence = isinstance(generator, data_utils.Sequence)

  if workers <= 0:
    # No background queue: iterate directly on the calling thread.
    if is_sequence:
      return data_utils.iter_sequence_infinite(generator), None
    return generator, None

  if is_sequence:
    enqueuer = data_utils.OrderedEnqueuer(
        generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
  else:
    enqueuer = data_utils.GeneratorEnqueuer(
        generator, use_multiprocessing=use_multiprocessing)
  enqueuer.start(workers=workers, max_queue_size=max_queue_size)
  return enqueuer.get(), enqueuer
519
520
def _make_execution_function(model, mode, class_weight=None):
  """Returns a callable that runs a single batch through `model`.

  Arguments:
    model: Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    class_weight: Optional class weights, forwarded to `train_on_batch` in
      TRAIN mode only.

  Returns:
    A function called with the unpacked batch tuple (`(x,)`, `(x, y)` or
    `(x, y, sample_weights)`). Outside PREDICT mode it is bound with
    `reset_metrics=False`.
  """
  if mode == ModeKeys.TRAIN:
    run_batch = functools.partial(
        model.train_on_batch, class_weight=class_weight)
  elif mode == ModeKeys.TEST:
    run_batch = model.test_on_batch
  else:
    # Accept (and ignore) targets/weights so PREDICT can consume the same
    # 1-, 2- or 3-tuples from the generator as the other modes.
    def predict_on_batch(x, y=None, sample_weights=None):  # pylint: disable=unused-argument
      return model.predict_on_batch(x)

    run_batch = predict_on_batch

  if mode == ModeKeys.PREDICT:
    return run_batch
  # Keep stateful metrics accumulating across batch-level calls.
  return functools.partial(run_batch, reset_metrics=False)
540
541
def _get_num_samples_or_steps(data, steps_per_epoch):
  """Picks the unit used to count progress through `data`.

  Arguments:
    data: The raw data passed to `model_iteration`.
    steps_per_epoch: Steps per epoch, used when `data` exposes no `shape`.

  Returns:
    `(int(first_input.shape[0]), False)` when the first flattened input of
    `data` has a `shape` attribute (count samples), otherwise
    `(steps_per_epoch, True)` (count steps).
  """
  first_input = nest.flatten(data)[0]
  if not hasattr(first_input, 'shape'):
    return steps_per_epoch, True
  return int(first_input.shape[0]), False
548