# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.engine import distributed_training_utils
from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def fit_distributed(model,
                    x=None,
                    y=None,
                    batch_size=None,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    validation_split=0.,
                    validation_data=None,
                    shuffle=True,
                    class_weight=None,
                    sample_weight=None,
                    initial_epoch=0,
                    steps_per_epoch=None,
                    validation_steps=None,
                    validation_freq=1):
  """Fit loop for Distribution Strategies."""
  distributed_training_utils.validate_callbacks(callbacks, model.optimizer)
  distributed_training_utils.validate_inputs(
      x, y, model._distribution_strategy)

  first_x_value = nest.flatten(x)[0]
  if isinstance(first_x_value, np.ndarray):
    # Until support for partial batches is implemented across all
    # functions and distribution strategies, we pass `mode` to selectively
    # relax the constraint that all training samples be consumed.
    steps_per_epoch, batch_size = (
        distributed_training_utils.get_input_params(
            model._distribution_strategy, first_x_value, steps_per_epoch,
            batch_size, mode=ModeKeys.TRAIN))
  batch_size = model._validate_or_infer_batch_size(
      batch_size, steps_per_epoch, x)
  dataset = model._distribution_standardize_user_data(
      x, y,
      sample_weight=sample_weight,
      class_weight=class_weight,
      batch_size=batch_size,
      validation_split=validation_split,
      shuffle=shuffle,
      repeat=True)

  val_dataset = None
  if validation_data:
    val_x, val_y, val_sample_weights = model._unpack_validation_data(
        validation_data)
    distributed_training_utils.validate_inputs(
        val_x, val_y, model._distribution_strategy)
    first_valx_value = nest.flatten(val_x)[0]
    if isinstance(first_valx_value, np.ndarray):
      validation_steps, _ = distributed_training_utils.get_input_params(
          model._distribution_strategy, first_valx_value, validation_steps,
          batch_size)
    val_dataset = model._distribution_standardize_user_data(
        val_x, val_y,
        sample_weight=val_sample_weights,
        class_weight=None,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle)
  elif validation_split:
    raise ValueError('validation_split argument is not supported with '
                     'distribution strategies.')

  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_fit_loop(
        model,
        dataset,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_dataset=val_dataset,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq)
  else:
    return training_arrays.fit_loop(
        model,
        dataset,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_dataset,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')


def evaluate_distributed(model,
                         x=None,
                         y=None,
                         batch_size=None,
                         verbose=1,
                         sample_weight=None,
                         steps=None,
                         callbacks=None):
  """Evaluate loop for Distribution Strategies."""
  distributed_training_utils.validate_inputs(x, y, model._distribution_strategy)
  first_x_value = nest.flatten(x)[0]
  if isinstance(first_x_value, np.ndarray):
    steps, batch_size = distributed_training_utils.get_input_params(
        model._distribution_strategy, first_x_value, steps, batch_size)
  batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
  dataset = model._distribution_standardize_user_data(
      x, y,
      sample_weight=sample_weight,
      batch_size=batch_size)

  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_test_loop(
        model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
  else:
    return training_arrays.test_loop(
        model,
        inputs=dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)


def predict_distributed(model,
                        x=None,
                        batch_size=None,
                        verbose=0,
                        steps=None,
                        callbacks=None):
  """Predict loop for Distribution Strategies."""
  distributed_training_utils.validate_inputs(
      x, None, model._distribution_strategy, allow_partial_batch=True)
  first_x_value = nest.flatten(x)[0]
  if isinstance(first_x_value, np.ndarray):
    steps, batch_size = distributed_training_utils.get_input_params(
        model._distribution_strategy, first_x_value, steps,
        batch_size, mode=ModeKeys.PREDICT)
  batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
  dataset = model._distribution_standardize_user_data(
      x,
      batch_size=batch_size,
      allow_partial_batch=True)
  if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
    return experimental_tpu_predict_loop(
        model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
  else:
    return training_arrays.predict_loop(
        model,
        dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)


def _make_step_fn(model, mode, strategy, output_labels):
  """Create a step function.

  Arguments:
    model: a Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    strategy: a `tf.distribute.Strategy` instance.
    output_labels: the output labels for the step function.

  Returns:
    A step function to be run by `tf.distribute.Strategy`.
  """

  def _per_device_execution_function(model):
    exec_func = model._make_execution_function(mode)
    return (exec_func.inputs, exec_func.outputs, exec_func.updates_op,
            exec_func.session_kwargs)

  def step_fn(ctx, inputs):
    """A step function that returns update ops."""
    if mode == ModeKeys.PREDICT:
      targets = None
    else:
      inputs, targets = inputs

    if model._compile_distribution:
      distributed_training_utils.clone_model_on_replicas(
          model, strategy, mode, inputs=inputs, targets=targets)
    else:
      distributed_training_utils._build_distributed_network(
          model, strategy, mode, inputs, targets)

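    # Build the execution function on each replica, then flatten the
    # per-replica inputs/outputs/updates so a single `K.function` can drive
    # all replicas in one session run.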
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_device_execution_function,
         args=(distributed_training_utils.get_distributed_model(model, mode),))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if mode == ModeKeys.PREDICT:
        ctx.set_last_step_output(label, output)
      else:
        if label == 'loss':
          reduce_op = ds_reduce_util.ReduceOp.SUM
        else:
          # We reduce all other metrics using mean for now. This is a
          # temporary workaround until new metrics are in place.
          reduce_op = ds_reduce_util.ReduceOp.MEAN
        ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These
    # should be handled appropriately.
    return combined_fn.updates_op

  return step_fn


def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      dataset: Dataset that returns inputs and targets.
      epochs: Number of times to iterate over the data.
      verbose: Integer. Verbosity mode: 0, 1, or 2.
      callbacks: List of callbacks to be called during training.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run).
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      val_dataset: Dataset for validation data.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.
      validation_freq: Only relevant if validation data is provided. Integer or
          `collections.Container` instance (e.g. list, tuple, etc.). If an
          integer, specifies how many training epochs to run before a new
          validation run is performed, e.g. `validation_freq=2` runs
          validation every 2 epochs. If a Container, specifies the epochs on
          which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
          validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
      The model's `History` object (`model.history`).

  Raises:
      ValueError: in case of invalid arguments.
  """
  mode = ModeKeys.TRAIN
  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
  current_strategy = model._distribution_strategy
  iterator = distributed_training_utils.get_iterator(dataset, current_strategy)
  steps_per_epoch = training_utils.infer_steps_for_dataset(
      dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
  if (current_strategy.extended.steps_per_run != 1 and
      steps_per_epoch is None):
    raise ValueError('`steps_per_epoch` should be specified when calling '
                     '`fit` on the model with TPUStrategy when '
                     '`steps_per_run` != 1.')

  scope = distributed_training_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_step_fn(model, ModeKeys.TRAIN, current_strategy, out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for name in model.metrics_names[1:]:
    tensor = model._all_metrics_tensors[name]
    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

  use_steps = steps_per_epoch is not None
  if use_steps:
    iteration_value = min(steps_per_epoch,
                          current_strategy.extended.steps_per_run)
  else:
    iteration_value = current_strategy.extended.steps_per_run

  steps_per_run = K.variable(
      value=iteration_value,
      dtype='int32',
      name='steps_per_run')
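  # `steps_per_run` is a variable (not a constant) so the training loop below
  # can shrink the final chunk of an epoch via `steps_per_run.load(...)`.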
  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    distributed_training_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate how many steps to run on the device each time, splitting the
  # epoch into chunks of `steps_per_run` plus one final partial chunk.
  if use_steps:
    steps_to_run = ([current_strategy.extended.steps_per_run] *
                    (steps_per_epoch //
                     current_strategy.extended.steps_per_run))
    if steps_per_epoch % current_strategy.extended.steps_per_run:
      steps_to_run.append(
          steps_per_epoch % current_strategy.extended.steps_per_run)
    target_steps = len(steps_to_run)
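    # For example, steps_per_epoch=100 with steps_per_run=8 yields
    # steps_to_run == [8] * 12 + [4], i.e. target_steps == 13 device
    # invocations per epoch.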
  else:
    target_steps = np.inf

  callbacks._call_begin_hook(mode)
  for epoch in range(initial_epoch, epochs):
    distributed_training_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step] if use_steps else 1
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
      if prev_step_count is None or step_count != prev_step_count:
        steps_per_run.load(step_count, K.get_session())
        prev_step_count = step_count
      try:
        _, outputs = K.batch_get_value([train_op, output_tensors])
      except errors.OutOfRangeError:
        if use_steps:
          logging.warning('Your dataset iterator ran out of data; '
                          'interrupting training. Make sure that your dataset '
                          'can generate at least `steps_per_epoch * epochs` '
                          'batches (in this case, %d batches).' %
                          (steps_per_epoch * epochs))
        else:
          target_steps = current_step
          logging.info('Dataset iterator ran out of data. Inferring the '
                       'value of `steps_per_epoch` as %s.' % target_steps)
          distributed_training_utils.initialize_iterator(iterator,
                                                         current_strategy)
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

      if callbacks.model.stop_training:
        break

    if (do_validation and
        training_utils.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model we need to copy
        # the weights back to the original model before we can run validation.
        distributed_training_utils._copy_weights_to_original_model(
            model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    distributed_training_utils._copy_weights_to_original_model(
        model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history


def experimental_tpu_test_loop(model,
                               dataset,
                               verbose=0,
                               steps=None,
                               callbacks=None):
  """Test loop for evaluating with TPU DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer. Verbosity mode: 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring the evaluation round finished.
          Ignored with the default value of `None`.
      callbacks: List of callbacks to be called during evaluation.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the outputs.
  """
  mode = ModeKeys.TEST
  current_strategy = model._distribution_strategy
  iterator = distributed_training_utils.get_iterator(dataset,
                                                     current_strategy)
  steps = training_utils.infer_steps_for_dataset(dataset, steps,
                                                 steps_name='steps')

  scope = distributed_training_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.metrics_names
  step_fn = _make_step_fn(model, ModeKeys.TEST, current_strategy, out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for name in model.metrics_names[1:]:
    tensor = model._all_metrics_tensors[name]
    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

  # TODO(priyag): Use steps_per_run when we use new metrics as they will
  # allow handling metric computation at each step using variables.
  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=1,
      initial_loop_values=initial_loop_values)

  test_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  if verbose == 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    distributed_training_utils._copy_weights_to_distributed_model(model, mode)

  distributed_training_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=ModeKeys.TEST)
  callbacks._call_begin_hook(mode)

  outs = [0.] * len(model.metrics_names)
  if steps is not None:
    target_steps = steps
  else:
    target_steps = np.inf

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = K.batch_get_value([test_op, output_tensors])
    except errors.OutOfRangeError:
      if steps is not None:
        warning_msg = ('Make sure that your dataset can generate at least '
                       '`steps` batches (in this case, {} batches).'.format(
                           steps))
      else:
        warning_msg = 'Number of steps ran: {} steps'.format(current_step)

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      target_steps = current_step
      break
    for i, label in enumerate(model.metrics_names):
      if i == 0:
        # Loss is a stateless metric; accumulate it across steps here.
        outs[i] += batch_outs[label]
      else:
        # For all stateful metrics, the aggregation is handled by
        # mirrored variables.
        outs[i] = batch_outs[label]

    batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose >= 1:
      progbar.update(current_step + 1)
    current_step += 1

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)
  if outs:
    # The per-step loss values were summed, so divide by the number of steps
    # to report the mean loss.
    outs[0] /= target_steps

  if len(outs) == 1:
    return outs[0]
  return outs


def experimental_tpu_predict_loop(model,
                                  dataset,
                                  verbose=0,
                                  steps=None,
                                  callbacks=None):
  """Predict loop for predicting with TPU DistributionStrategy.

  Arguments:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer. Verbosity mode: 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished.
          Ignored with the default value of `None`.
      callbacks: List of callbacks to be called during prediction.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
  mode = ModeKeys.PREDICT
  steps = training_utils.infer_steps_for_dataset(dataset, steps,
                                                 steps_name='steps')
  dataset_fully_shaped = (distributed_training_utils.
                          is_dataset_shape_fully_defined(dataset))
  padding_handler = None
  if not dataset_fully_shaped:
    # TODO(hongjunchoi): Investigate whether operations from
    # PartialBatchPaddingHandler are unnecessarily pruned out
    # during graph optimization.
    padding_handler = padding_util.PartialBatchPaddingHandler(
        model._feed_output_shapes)
    batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
    padding_handler.padded_batch_size = batch_size
    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
                                                  padding_handler.update_mask)

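    # The padded rows introduced below are stripped from the final
    # predictions via `padding_handler.apply_mask` at the end of this
    # function.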
    dataset = dataset.map(padding_handler.pad_batch)
    dataset = dataset.apply(batching.unbatch())
    # At this point, the dataset is guaranteed to have no partial batches,
    # so we can set `drop_remainder=True` to get static shape information
    # about the elements in the dataset.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    if prefetch_buffer is not None:
      dataset = dataset.prefetch(prefetch_buffer)

  current_strategy = model._distribution_strategy
  iterator = distributed_training_utils.get_iterator(dataset, current_strategy)

  scope = distributed_training_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.output_names
  step_fn = _make_step_fn(model, ModeKeys.PREDICT, current_strategy, out_labels)

  # Add initial dummy values for outputs.
  initial_loop_values = {}
  batch_dimension = distributed_training_utils.get_batch_dimension(iterator)
  for name, tensor in zip(model.output_names, model.outputs):
    # TODO(priyag): This is a workaround as we do not know the batch dimension
    # of the model's output at this point.
    shape = tensor_shape.TensorShape(tensor.shape.dims)
    shape.dims = [batch_dimension] + shape.dims[1:]
    initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)

  # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=1,
      initial_loop_values=initial_loop_values)

  predict_op = ctx.run_op
  output_tensors = ctx.last_step_outputs

  if verbose == 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    distributed_training_utils._copy_weights_to_distributed_model(model, mode)

  distributed_training_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=mode)
  callbacks._call_begin_hook(mode)

  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  unconcatenated_outs = [[] for _ in model.outputs]
  if steps is not None:
    target_steps = steps
  else:
    target_steps = np.inf

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = K.batch_get_value([predict_op, output_tensors])
    except errors.OutOfRangeError:
      if steps is not None:
        warning_msg = ('Make sure that your dataset can generate at least '
                       '`steps` batches (in this case, {} batches).'.format(
                           steps))
      else:
        warning_msg = 'Number of steps ran: {} steps'.format(current_step)

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting prediction. ' + warning_msg)
      break

    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
    for i, label in enumerate(model.output_names):
      unconcatenated_outs[i].extend(batch_outs[label])
    batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose >= 1:
      progbar.update(current_step + 1)
    current_step += 1

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)

  if len(unconcatenated_outs) == 1:
    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
  else:
    prediction_result = [
        np.concatenate(unconcatenated_outs[i], axis=0)
        for i in range(len(unconcatenated_outs))
    ]

  if padding_handler:
    prediction_result = padding_handler.apply_mask(prediction_result)

  return prediction_result