# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to plain array data.
"""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.generic_utils import make_batches
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging

try:
  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
except ImportError:
  issparse = None

def model_iteration(model,
                    inputs,
                    targets=None,
                    sample_weights=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    val_inputs=None,
                    val_targets=None,
                    val_sample_weights=None,
                    shuffle=True,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1,
                    mode=ModeKeys.TRAIN,
                    validation_in_fit=False,
                    prepared_feed_values_from_dataset=False,
                    steps_name='steps',
                    **kwargs):
  """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

  Arguments:
      model: Keras Model instance.
      inputs: Either a list or dictionary of arrays, or a dataset instance.
      targets: List/dictionary of target arrays.
      sample_weights: Optional list of sample weight arrays.
      batch_size: Integer batch size or None if unknown.
      epochs: Number of times to iterate over the data.
      verbose: Verbosity mode, 0, 1 or 2.
      callbacks: List of callbacks to be called during training.
      val_inputs: Either a list or dictionary of arrays, or a dataset instance.
      val_targets: List/dictionary of target arrays.
      val_sample_weights: Optional list of sample weight arrays.
      shuffle: Whether to shuffle the data at the beginning of each epoch.
      initial_epoch: Epoch at which to start training (useful for resuming a
        previous training run).
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      validation_steps: Number of steps to run validation for (only if doing
        validation from data tensors). Ignored with the default value of `None`.
      validation_freq: Only relevant if validation data is provided. Integer or
        `collections.Container` instance (e.g. list, tuple, etc.). If an
        integer, specifies how many training epochs to run before a new
        validation run is performed, e.g. `validation_freq=2` runs
        validation every 2 epochs. If a Container, specifies the epochs on
        which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
        validation at the end of the 1st, 2nd, and 10th epochs.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      validation_in_fit: if True, this method is being invoked from within a
        training iteration (for validation). When `val_inputs` is a dataset,
        this flag indicates that its iterator and feed values have already been
        created, so those resources should be reused rather than recreated.
      prepared_feed_values_from_dataset: if True, `inputs` is a list of feed
        tensors returned from a `_prepare_feed_values` call on the validation
        dataset, so `_prepare_feed_values` should not be called again on
        `inputs`. Should only be used for inline validation (i.e., only if
        `validation_in_fit` is also True).
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.
      **kwargs: Additional arguments for backwards compatibility.

  Returns:
      - In TRAIN mode: `History` object.
      - In TEST mode: Evaluation metrics.
      - In PREDICT mode: Outputs of the Model called on inputs.

  Raises:
      ValueError: in case of invalid arguments.
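
  Example:
      A minimal sketch of the TRAIN-mode call that `Model.fit` makes into this
      loop (`model` is assumed to be a compiled Keras model, and `x`, `y` are
      illustrative NumPy arrays):

      history = model_iteration(
          model, x, targets=y, batch_size=32, epochs=2, verbose=0,
          mode=ModeKeys.TRAIN)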
118  """
119  # Backwards compatibility.
120  if 'steps' in kwargs:
121    steps_per_epoch = kwargs.pop('steps')
122  if kwargs:
123    raise TypeError('Unknown arguments: %s' % (kwargs,))
124
125  # In case we were passed a dataset, we extract symbolic tensors from it.
126  reset_dataset_after_each_epoch = False
127  input_iterator = None
128  is_dataset = isinstance(inputs,
129                          (dataset_ops.DatasetV1, dataset_ops.DatasetV2))
130  # TODO(fchollet): consider moving `steps_per_epoch` inference to
131  # _standardize_user_data and set reset_dataset_after_each_epoch as an
132  # attribute on the dataset instance.
133  if is_dataset:
134    if steps_per_epoch is None:
135      reset_dataset_after_each_epoch = True
136      steps_per_epoch = training_utils.infer_steps_for_dataset(
137          inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name)
138    input_iterator = _get_iterator(inputs, model._distribution_strategy)
139
140  if mode == ModeKeys.TRAIN:
141    _print_train_info(inputs, val_inputs, steps_per_epoch, verbose)
142
143  # Enter DistributionStrategy scope.
144  if model._distribution_strategy:
145    scope = distributed_training_utils.distributed_scope(
146        strategy=model._distribution_strategy,
147        learning_phase=(1 if mode == ModeKeys.TRAIN else 0))
148    scope.__enter__()
149
150  # Get step function and loop type.
151  f = _make_execution_function(model, mode)
152  use_steps = is_dataset or steps_per_epoch is not None
153  do_validation = val_inputs is not None
154
155  # Convert Eager Tensors to NumPy arrays to support batching/shuffling.
156  inputs, targets, sample_weights = training_utils. \
157      convert_eager_tensors_to_numpy((inputs, targets, sample_weights))
158
159  # Prepare input data.
160  inputs = input_iterator or inputs
  if validation_in_fit and prepared_feed_values_from_dataset:
    # When invoking validation from within the training loop, avoid creating
    # the iterator and the list of feed values for the same validation dataset
    # multiple times (doing so would repeatedly call `iterator.get_next()`,
    # which slows down execution and eventually leads to OOM errors).
    ins = inputs
  else:
    ins = _prepare_feed_values(model, inputs, targets, sample_weights, mode)
  if not is_dataset:
    num_samples_or_steps = _get_num_samples_or_steps(ins, batch_size,
                                                     steps_per_epoch)
  else:
    num_samples_or_steps = steps_per_epoch

  # Prepare validation data. Hold references to the iterator and the input list
  # to properly reinitialize and reuse in multiple validation passes.
  val_iterator = None
  if isinstance(val_inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
    if validation_steps is None:
      # Because we pass an iterator feed instead of a Dataset to the eval
      # model_iteration() call, it will not trigger the dataset-input path
      # that determines the number of steps required. To avoid this issue,
      # set validation_steps here if validation_steps is None.
      validation_steps = training_utils.infer_steps_for_dataset(
          val_inputs,
          validation_steps,
          epochs=epochs,
          steps_name='validation_steps')
    val_iterator = _get_iterator(val_inputs, model._distribution_strategy)
    val_inputs = _prepare_feed_values(
        model, val_iterator, val_targets, val_sample_weights, ModeKeys.TEST)

  # Configure callbacks.
  count_mode = 'steps' if use_steps else 'samples'
  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=num_samples_or_steps,
      verbose=0,  # Handle ProgBarLogger separately in this loop.
      mode=mode)
  # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
  progbar = training_utils.get_progbar(model, count_mode)
  progbar.params = callbacks.params
  progbar.params['verbose'] = verbose

  # Determine in advance which input arrays require sparse-to-dense conversion.
  if issparse is not None and not use_steps:
    indices_for_conversion_to_dense = []
    feed = _get_model_feed(model, mode)
    for i, (input_data, feed_tensor) in enumerate(zip(ins, feed)):
      if issparse(input_data) and not K.is_sparse(feed_tensor):
        indices_for_conversion_to_dense.append(i)

  # Select aggregation method.
  if mode == ModeKeys.PREDICT:
    aggregator = training_utils.OutputsAggregator(use_steps,
                                                  num_samples_or_steps)
  else:
    aggregator = training_utils.MetricsAggregator(use_steps,
                                                  num_samples_or_steps)

  if model._compile_distribution:
    distributed_training_utils._copy_weights_to_distributed_model(model, mode)

  callbacks.model.stop_training = False
  callbacks._call_begin_hook(mode)
  progbar.on_train_begin()

  for epoch in range(initial_epoch, epochs):
    if callbacks.model.stop_training:
      break

    # Setup work for each epoch
    epoch_logs = {}
    model.reset_metrics()
    if mode == ModeKeys.TRAIN:
      callbacks.on_epoch_begin(epoch, epoch_logs)
    progbar.on_epoch_begin(epoch, epoch_logs)

    if use_steps:
      # Step-wise loop.
      if steps_per_epoch is None:
        # Loop over dataset until `OutOfRangeError` is raised.
        target_steps = np.inf
      else:
        # Loop over dataset for the specified number of steps.
        target_steps = steps_per_epoch

      step = 0
      while step < target_steps:
        batch_logs = {'batch': step, 'size': 1}
        callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
        progbar.on_batch_begin(step, batch_logs)

        # Get outputs.
        try:
          # `ins` can be callable in DistributionStrategy + eager case.
          actual_inputs = ins() if callable(ins) else ins
          batch_outs = f(actual_inputs)
        except errors.OutOfRangeError:
          if is_dataset:
            # The dataset passed by the user ran out of batches.
            # Now we know the cardinality of the dataset.
            # If steps_per_epoch was specified, then running out of data is
            # unexpected, so we stop training and inform the user.
            if steps_per_epoch:
              callbacks.model.stop_training = True
              logging.warning(
                  'Your dataset ran out of data; interrupting training. '
                  'Make sure that your dataset can generate at least '
                  '`%s * epochs` batches (in this case, %d batches). '
                  'You may need to use the repeat() function when '
                  'building your dataset.'
                  % (steps_name, steps_per_epoch * epochs))
            elif step > 0:
              steps_per_epoch = step
              aggregator.num_samples_or_steps = steps_per_epoch
              if mode == ModeKeys.TRAIN:
                progbar.params['steps'] = steps_per_epoch
                progbar.progbar.target = steps_per_epoch
          else:
            # We ran out of batches while the user passed an iterator (legacy).
            callbacks.model.stop_training = True
            logging.warning(
                'Your dataset iterator ran out of data; '
                'interrupting training. Make sure that your iterator '
                'can generate at least `%s * epochs` '
                'batches (in this case, %d batches). You may need to '
                'use the repeat() function when building your '
                'dataset.' % (steps_name, steps_per_epoch * epochs))
          break

        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        if model._distribution_strategy:
          batch_outs = distributed_training_utils._per_device_aggregate_batch(
              batch_outs, model, mode)

        # Aggregate results.
        if step == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', step, batch_logs)
        progbar.on_batch_end(step, batch_logs)
        step += 1

        if callbacks.model.stop_training:
          break
    else:
      # Sample-wise loop.
      index_array = np.arange(num_samples_or_steps)
      if shuffle == 'batch':
        index_array = training_utils.batch_shuffle(index_array, batch_size)
      elif shuffle:
        np.random.shuffle(index_array)
      batches = make_batches(num_samples_or_steps, batch_size)
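      # `batches` is a list of contiguous (start, end) index pairs, e.g.
      # make_batches(5, 2) returns [(0, 2), (2, 4), (4, 5)].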

      for batch_index, (batch_start, batch_end) in enumerate(batches):
        batch_ids = index_array[batch_start:batch_end]

        # Slice into a batch.
        try:
          if ins and isinstance(ins[-1], int):
            # Do not slice the training phase flag.
            ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
          else:
            ins_batch = slice_arrays(ins, batch_ids)
        except TypeError:
          raise TypeError('TypeError while preparing batch. '
                          'If using HDF5 input data, '
                          'pass shuffle="batch".')

        # Sparse to dense conversion.
        if issparse is not None:
          for i in indices_for_conversion_to_dense:
            ins_batch[i] = ins_batch[i].toarray()

        # Callbacks batch_begin.
        batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
        callbacks._call_batch_hook(mode, 'begin', batch_index, batch_logs)
        progbar.on_batch_begin(batch_index, batch_logs)

        # Get outputs.
        batch_outs = f(ins_batch)
        if not isinstance(batch_outs, list):
          batch_outs = [batch_outs]

        # Aggregate results.
        if batch_index == 0:
          aggregator.create(batch_outs)
        aggregator.aggregate(batch_outs, batch_start, batch_end)

        # Callbacks batch end.
        batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, 'end', batch_index, batch_logs)
        progbar.on_batch_end(batch_index, batch_logs)

        if callbacks.model.stop_training:
          break

    aggregator.finalize()
    results = aggregator.results
    epoch_logs = cbks.make_logs(model, epoch_logs, results, mode)
    if len(results) == 1:
      results = results[0]

    # Run the test loop every `validation_freq` epochs during training.
    if (do_validation and
        training_utils.should_run_validation(validation_freq, epoch) and
        not callbacks.model.stop_training):

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        distributed_training_utils._copy_weights_to_original_model(
            model, ModeKeys.TRAIN)

      val_results = model_iteration(
          model,
          val_inputs,
          targets=val_targets,
          sample_weights=val_sample_weights,
          batch_size=batch_size,
          steps_per_epoch=validation_steps,
          callbacks=callbacks,
          verbose=0,
          mode=ModeKeys.TEST,
          validation_in_fit=True,
          prepared_feed_values_from_dataset=(val_iterator is not None),
          steps_name='validation_steps')
      if not isinstance(val_results, list):
        val_results = [val_results]
      epoch_logs = cbks.make_logs(
          model, epoch_logs, val_results, mode, prefix='val_')
      if val_iterator and epoch < epochs - 1:
        _reinitialize_iterator(val_iterator, model._distribution_strategy)

    if mode == ModeKeys.TRAIN:
      # Epochs only apply to `fit`.
      callbacks.on_epoch_end(epoch, epoch_logs)
    progbar.on_epoch_end(epoch, epoch_logs)

    # Reinitialize dataset iterator for the next epoch.
    if reset_dataset_after_each_epoch and epoch < epochs - 1:
      _reinitialize_iterator(input_iterator, model._distribution_strategy)

  callbacks._call_end_hook(mode)

  if model._distribution_strategy:
    if model._compile_distribution:
      # TODO(priyag, psv): Copy back metrics to the original model as well?
      distributed_training_utils._copy_weights_to_original_model(model, mode)
    scope.__exit__(None, None, None)

  if mode == ModeKeys.TRAIN:
    return model.history
  return results


def _get_model_feed(model, mode):
  if mode == ModeKeys.PREDICT:
    feed = model._feed_inputs
  else:
    feed = (
        model._feed_inputs + model._feed_targets + model._feed_sample_weights)
  return feed


def _print_train_info(inputs, val_inputs, steps_per_epoch, verbose):
  if (val_inputs and steps_per_epoch is None and verbose and inputs and
      hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
    print('Train on %d samples, validate on %d samples' %
          (inputs[0].shape[0], val_inputs[0].shape[0]))


def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch):
  """Returns total number of samples (when training in batch mode) or steps."""
  if steps_per_epoch:
    return steps_per_epoch
  return training_utils.check_num_samples(ins, batch_size, steps_per_epoch,
                                          'steps_per_epoch')


def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
453  """Prepare feed values to the model execution function.

  Arguments:
    model: Model to prepare feed values for.
    inputs: List or dict of model inputs.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode.
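
  Example:
    For a single-input, single-output model in TRAIN mode (without a
    distribution strategy), the returned feed is typically
    `[x, y, sample_weight, True]`, where the trailing `True` is the
    learning-phase value. This is only a sketch; the exact layout depends on
    the model's feed structure and the mode.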
464  """
465  if model._distribution_strategy:
466    if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
467      inputs = distributed_training_utils.get_iterator(
468          inputs, model._distribution_strategy)
469
470    def get_distributed_inputs():
471      return distributed_training_utils._prepare_feed_values(
472          model, inputs, targets, sample_weights, mode)
473
    # In the eager case, we want to call the input method once per step, so we
    # return a callable from here. Note that this is applicable only to the
    # Distribution Strategy case, as it follows the same code path for both
    # eager and graph modes.
    # TODO(priyag,omalleyt): Either we should move the training DS with
    # EagerIterator to use training_generator code path, or figure out how to
    # set a symbolic Iterator out of a Dataset when in eager mode.
    if context.executing_eagerly():
      return get_distributed_inputs
    else:
      return get_distributed_inputs()

  if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
                         iterator_ops.Iterator)):
    inputs, targets, sample_weights = model._standardize_user_data(
        inputs,
        extract_tensors_from_dataset=True)

  inputs = training_utils.ModelInputs(inputs).as_list()
  targets = targets or []
  sample_weights = sample_weights or []
  ins = inputs + targets + sample_weights
  if mode == ModeKeys.TRAIN and not isinstance(K.symbolic_learning_phase(),
                                               int):
    ins += [True]  # Add learning phase value.
  return ins


def _get_iterator(inputs, distribution_strategy=None):
  if distribution_strategy:
    return distributed_training_utils.get_iterator(
        inputs, distribution_strategy)
  return training_utils.get_iterator(inputs)


def _reinitialize_iterator(iterator, distribution_strategy=None):
  if distribution_strategy:
    distributed_training_utils.initialize_iterator(
        iterator, distribution_strategy)
  else:
    training_utils.initialize_iterator(iterator)


def _make_execution_function(model, mode):
  """Makes function to run one step of model execution."""
  if model._distribution_strategy:
    return distributed_training_utils._make_execution_function(model, mode)
  return model._make_execution_function(mode)


# For backwards compatibility for internal users of these loops.
fit_loop = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
test_loop = functools.partial(
    model_iteration, mode=ModeKeys.TEST, shuffle=False)
predict_loop = functools.partial(
    model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
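# For example, `fit_loop(model, x, targets=y, epochs=2)` is equivalent to
# `model_iteration(model, x, targets=y, epochs=2, mode=ModeKeys.TRAIN)`;
# `x` and `y` are illustrative placeholders here.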