# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils
from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util
from tensorflow.python.keras.engine import training_arrays_v1
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging


def _per_replica_execution_function(model, mode):
  exec_func = model._make_execution_function(mode)
  return (exec_func.inputs, exec_func.outputs, exec_func.updates_op,
          exec_func.session_kwargs)


def _build_model(strategy, model, mode, inputs, targets=None):
  if model._compile_distribution:
    dist_utils.clone_model_on_replicas(
        model, strategy, mode, inputs=inputs, targets=targets)
  else:
    dist_utils._build_distributed_network(model, strategy, mode, inputs,
                                          targets)


def _make_train_step_fn(model, mode, strategy, output_labels):
  """Create step fn.

  Args:
    model: a Keras Model instance.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
    strategy: a `tf.distribute.Strategy` instance.
    output_labels: the output labels for the step function.

  Returns:
    A step function to run by `tf.distribute.Strategy`.
  """

  def _step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    # When the input feature is a dictionary of tensors, the dictionary is
    # flattened to an array and passed as a model input. This results in an
    # input mismatch when the model's input layer names are not sorted in
    # alphabetical order, since `nest.flatten()` sorts dictionary elements by
    # key. To avoid this, transform the input tensors into an array ordered
    # along `model._feed_input_names`.
    if isinstance(inputs, dict):
      inputs = [inputs[input_name] for input_name in model._feed_input_names]
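      # For example (hypothetical names): if `model._feed_input_names` is
      # ['a', 'b'], an input dict {'b': tb, 'a': ta} becomes [ta, tb].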

    _build_model(strategy, model, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_execution_function,
         args=(dist_utils.get_distributed_model(model, mode), mode))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = dist_utils.unwrap_values(strategy, grouped_inputs,
                                                  grouped_outputs,
                                                  grouped_updates,
                                                  grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = ds_reduce_util.ReduceOp.SUM
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): feed_dict, session kwargs, run options, and
    # run_metadata from the combined_fn are ignored for now; these should be
    # handled appropriately.
    return combined_fn.updates_op

  return _step_fn


def experimental_tpu_fit_loop(model,
                              dataset,
                              epochs=100,
                              verbose=1,
                              callbacks=None,
                              initial_epoch=0,
                              steps_per_epoch=None,
                              val_dataset=None,
                              validation_steps=None,
                              validation_freq=1):
  """Fit loop for training with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset that returns inputs and targets.
      epochs: Number of times to iterate over the data.
      verbose: Integer. Verbosity mode, 0, 1, or 2.
      callbacks: List of callbacks to be called during training.
      initial_epoch: Epoch at which to start training
          (useful for resuming a previous training run).
      steps_per_epoch: Total number of steps (batches of samples)
          before declaring one epoch finished and starting the
          next epoch. Ignored with the default value of `None`.
      val_dataset: Dataset for validation data.
      validation_steps: Number of steps to run validation for
          (only if doing validation from data tensors).
          Ignored with the default value of `None`.
      validation_freq: Only relevant if validation data is provided. Integer or
          `collections.abc.Container` instance (e.g. list, tuple, etc.). If an
          integer, specifies how many training epochs to run before a new
          validation run is performed, e.g. `validation_freq=2` runs
          validation every 2 epochs. If a Container, specifies the epochs on
          which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
          validation at the end of the 1st, 2nd, and 10th epochs.

  Returns:
      The trained model's `History` object.

  Raises:
      ValueError: in case of invalid arguments.
  """
  mode = ModeKeys.TRAIN

  current_strategy = model._distribution_strategy
  iteration_value = min(steps_per_epoch,
                        current_strategy.extended.steps_per_run)
  steps_per_run = K.variable(
      value=iteration_value,
      dtype='int32',
      name='steps_per_run')

  # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=1)
  scope.__enter__()

  out_labels = model.metrics_names or []

  step_fn = _make_train_step_fn(model, ModeKeys.TRAIN, current_strategy,
                                out_labels)

  # Add initial dummy values for loss and other metric tensors.
  initial_loop_values = {}
  initial_loop_values['loss'] = constant_op.constant(1e7)
  for m in model._get_training_eval_metrics():
    tensor = m.result()
    initial_loop_values[m.name] = array_ops.zeros(tensor.shape, tensor.dtype)

  ctx = current_strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=steps_per_run,
      initial_loop_values=initial_loop_values)
  train_op = ctx.run_op
  output_tensors = ctx.last_step_outputs
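  # Each session run of `train_op` executes `steps_per_run` training steps on
  # the device before control returns to the host.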

  do_validation = bool(validation_steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=do_validation,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      verbose=verbose,
      count_mode='steps',
      mode=mode)

  # Calculate how many steps to run on the device each time.
  steps_to_run = ([current_strategy.extended.steps_per_run] *
                  (steps_per_epoch //
                   current_strategy.extended.steps_per_run))
  if steps_per_epoch % current_strategy.extended.steps_per_run:
    steps_to_run.append(
        steps_per_epoch % current_strategy.extended.steps_per_run)
  target_steps = len(steps_to_run)
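  # For example, steps_per_epoch=10 with steps_per_run=4 gives
  # steps_to_run == [4, 4, 2] and target_steps == 3.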

  callbacks._call_begin_hook(mode)

  initial_epoch = model._maybe_load_initial_epoch_from_ckpt(initial_epoch, mode)

  for epoch in range(initial_epoch, epochs):
    dist_utils._reset_metrics(model)
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    step_index = 0
    prev_step_count = None
    current_step = 0
    while current_step < target_steps:
      step_count = steps_to_run[current_step]
      batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
      callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
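      # Re-assign `steps_per_run` only when the chunk size changes (typically
      # just for the final, smaller chunk), avoiding an extra session run on
      # every host-loop iteration.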
      if prev_step_count is None or step_count != prev_step_count:
        K.get_session().run(steps_per_run.assign(step_count))
        prev_step_count = step_count
      try:
        _, outputs = K.batch_get_value([train_op, output_tensors])
      except errors.OutOfRangeError:
        logging.warning('Your dataset iterator ran out of data; '
                        'interrupting training. Make sure that your dataset '
                        'can generate at least `steps_per_epoch * epochs` '
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
        break

      batch_logs.update(outputs)
      callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
      step_index = step_index + step_count
      current_step += 1

      if callbacks.model.stop_training:
        break

    if (do_validation and
        training_utils_v1.should_run_validation(validation_freq, epoch)):
      logging.info('Running validation at fit epoch: %s', epoch)

      if model._compile_distribution:
        # Since we create a new clone from the original model, we need to copy
        # the weights back to the original model before we can run validation.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)

      val_outs = experimental_tpu_test_loop(  # pylint: disable=undefined-variable
          model,
          val_dataset,
          steps=validation_steps,
          verbose=verbose,
          callbacks=callbacks)
      if not isinstance(val_outs, list):
        val_outs = [val_outs]
      # Same labels assumed.
      for label, val_out in zip(out_labels, val_outs):
        epoch_logs['val_' + label] = val_out

    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  model._successful_loop_finish = True
  callbacks._call_end_hook(mode)

  if model._compile_distribution:
    # Copy the weights back from the replicated model to the original model.
    dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
  scope.__exit__(None, None, None)
  return model.history


def experimental_tpu_test_loop(model,
                               dataset,
                               verbose=0,
                               steps=None,
                               callbacks=None):
  """Test loop for evaluating with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer. Verbosity mode, 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring evaluation finished.
          Ignored with the default value of `None`.
      callbacks: List of callbacks to be called during evaluation.

  Returns:
      Scalar loss (if the model has a single output and no metrics)
      or list of scalars (if the model has multiple outputs
      and/or metrics). The attribute `model.metrics_names` will give you
      the display labels for the outputs.
  """
  mode = ModeKeys.TEST
  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  out_labels = model.metrics_names

  def _test_step_fn(inputs):
326    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)
    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  test_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _test_step_fn, args=(test_input_data,))
  output_tensors = {}
  for label, output in zip(out_labels, per_replica_outputs):
    if label == 'loss':
      reduce_op = ds_reduce_util.ReduceOp.SUM
    else:
      # We reduce all other metrics using mean for now. This is a temporary
      # workaround until new metrics are in place.
      reduce_op = ds_reduce_util.ReduceOp.MEAN
    output_tensors[label] = current_strategy.reduce(reduce_op, output,
                                                    axis=None)
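  # Note: `axis=None` above reduces across replicas only; the per-replica
  # values are already scalars, so no batch-axis reduction is needed.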
  test_op = control_flow_ops.group(list(output_tensors.values()))

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=ModeKeys.TEST)
  callbacks._call_begin_hook(mode)

  outs = [0.] * len(model.metrics_names)
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      _, batch_outs = K.batch_get_value([test_op, output_tensors])
    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting evaluation. ' + warning_msg)
      target_steps = current_step
      break
    for i, label in enumerate(model.metrics_names):
      if i == 0:
        # Loss is a stateless metric; accumulate it across batches.
        outs[i] += batch_outs[label]
      else:
        # For all stateful metrics, the aggregation is handled by mirrored vars.
        outs[i] = batch_outs[label]

    batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(target_steps)
  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)
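  # The summed loss is averaged over the number of steps actually run; the
  # stateful metrics were already aggregated in place above.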
  if outs:
    outs[0] /= target_steps

  if len(outs) == 1:
    return outs[0]
  return outs


def experimental_tpu_predict_loop(model,
                                  dataset,
                                  verbose=0,
                                  steps=None,
                                  callbacks=None):
  """Predict loop for predicting with TPU tf.distribute.Strategy.

  Args:
      model: Keras Model instance.
      dataset: Dataset for input data.
      verbose: Integer. Verbosity mode, 0 or 1.
      steps: Total number of steps (batches of samples)
          before declaring `_predict_loop` finished.
          Ignored with the default value of `None`.
      callbacks: List of callbacks to be called during prediction.

  Returns:
      Array of predictions (if the model has a single output)
      or list of arrays of predictions
      (if the model has multiple outputs).
  """
  mode = ModeKeys.PREDICT
  dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
  padding_handler = None
  if not dataset_fully_shaped:
    # TODO(hongjunchoi): Investigate whether operations from
    # PartialBatchPaddingHandler are unnecessarily pruned out
    # during graph optimization.
    padding_handler = padding_util.PartialBatchPaddingHandler(
        model._feed_output_shapes)
    batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
    padding_handler.padded_batch_size = batch_size
    padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
                                                  padding_handler.update_mask)

    dataset = dataset.map(padding_handler.pad_batch)
    dataset = dataset.unbatch()
    # At this point, the dataset is guaranteed not to have partial batches.
    # Thus, we set `drop_remainder=True` to get static shape information
    # about the elements in the dataset.
    dataset = dataset.batch(batch_size, drop_remainder=True)
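    # For example, Dataset.range(10).batch(3, drop_remainder=True) yields
    # elements with static shape [3], whereas drop_remainder=False would give
    # the partial shape [None].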

    if prefetch_buffer is not None:
      dataset = dataset.prefetch(prefetch_buffer)

  current_strategy = model._distribution_strategy
  iterator = dist_utils.get_iterator(dataset, current_strategy)

  scope = dist_utils.distributed_scope(
      strategy=current_strategy, learning_phase=0)
  scope.__enter__()

  def _predict_step_fn(inputs):
477    """A fn that returns output of single prediction step."""

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs)))

    (_, outputs, updates, _) = _per_replica_execution_function(
        dist_utils.get_distributed_model(model, mode), mode)

    with ops.control_dependencies([updates]):
      return [array_ops.identity(out) for out in outputs]

  # TODO(hongjunchoi): When a numpy array is passed as an input to `predict()`,
  # use the numpy array directly to avoid accumulating unnecessary input
  # pipeline ops.
  predict_input_data = iterator.get_next()
  per_replica_outputs = current_strategy.run(
      _predict_step_fn, args=(predict_input_data,))
  output_tensors = dist_utils.flatten_per_replica_values(
      current_strategy, per_replica_outputs)

  if verbose >= 1:
    progbar = Progbar(target=steps)

  if model._compile_distribution:
    dist_utils._copy_weights_to_distributed_model(model, mode)

  dist_utils._reset_metrics(model)

  callbacks = cbks.configure_callbacks(
      callbacks,
      model,
      do_validation=False,
      epochs=1,
      steps_per_epoch=steps,
      verbose=verbose,
      count_mode='steps',
      mode=mode)
  callbacks._call_begin_hook(mode)

  # Since we do not know how many samples we will see, we cannot pre-allocate
  # the returned Numpy arrays. Instead, we store one array per batch seen
  # and concatenate them upon returning.
  num_model_outputs = len(model.output_names)
  unconcatenated_outs = [[] for _ in range(num_model_outputs)]
  if steps is not None:
    target_steps = steps
  else:
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')

  current_step = 0
  while current_step < target_steps:
    batch_logs = {'batch': current_step, 'size': 1}
    callbacks._call_batch_hook(mode, 'begin', current_step, batch_logs)
    try:
      predict_ops = control_flow_ops.group(output_tensors)
      _, batch_outs = K.batch_get_value([predict_ops, output_tensors])

    except errors.OutOfRangeError:
      warning_msg = (
          'Make sure that your dataset can generate at least '
          '`steps` batches (in this case, {} batches).'.format(steps))

      logging.warning('Your dataset iterator ran out of data; '
                      'interrupting prediction. ' + warning_msg)
      break

    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
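    # `batch_outs` is a flat list of length
    # num_model_outputs * num_replicas_in_sync; e.g. with 2 model outputs and
    # 8 replicas, output 0 occupies indices 0..7 and output 1 indices 8..15.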
    for i in range(num_model_outputs):
      output_start_index = i * current_strategy.num_replicas_in_sync
      output_end_index = (
          output_start_index + current_strategy.num_replicas_in_sync)
      single_model_output = batch_outs[output_start_index:output_end_index]
      unconcatenated_outs[i].extend(single_model_output)

    batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
    callbacks._call_batch_hook(mode, 'end', current_step, batch_logs)
    if verbose == 1:
      progbar.update(current_step + 1)
    current_step += 1

  if verbose >= 1:
    # Progress bar finishes at the end.
    progbar.update(current_step)

  callbacks._call_end_hook(mode)

  scope.__exit__(None, None, None)

  if len(unconcatenated_outs) == 1:
    prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
  else:
    prediction_result = [
        np.concatenate(out, axis=0) for out in unconcatenated_outs
    ]

  if padding_handler:
    prediction_result = padding_handler.apply_mask(prediction_result)

  return prediction_result


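# Illustrative sketch (not part of this module): in TF1 graph mode, a tf.keras
# model compiled under a distribution strategy scope routes fit/evaluate/
# predict through the loops below, e.g.:
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
#     model.compile(optimizer='sgd', loss='mse')
#   model.fit(dataset, epochs=2, steps_per_epoch=10)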
class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with a single worker."""

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Fit loop for Distribution Strategies."""
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    dist_utils.validate_inputs(x, y)

    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        model._distribution_strategy,
        x,
        batch_size,
        steps_per_epoch,
        ModeKeys.TRAIN,
        validation_split=validation_split)
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        class_weight=class_weight,
        batch_size=batch_size,
        validation_split=validation_split,
        shuffle=shuffle,
        epochs=epochs)
    if not dist_utils.is_distributing_by_cloning(model):
      with model._distribution_strategy.scope():
        (dataset, _, _) = model._standardize_user_data(
            dataset,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle)

    val_dataset = None
    if validation_data:
      val_x, val_y, val_sample_weights = (
          training_utils_v1.unpack_validation_data(validation_data))
      dist_utils.validate_inputs(val_x, val_y)
      _, validation_steps = dist_utils.process_batch_and_step_size(
          model._distribution_strategy, val_x, batch_size, validation_steps,
          ModeKeys.TEST)

      val_dataset = model._distribution_standardize_user_data(
          val_x, val_y,
          sample_weight=val_sample_weights,
          class_weight=None,
          batch_size=batch_size,
          validation_split=validation_split,
          shuffle=shuffle,
          allow_partial_batch=True)
    elif validation_split:
      raise ValueError('validation_split argument is not supported with '
                       'distribution strategies.')

    if K.is_tpu_strategy(model._distribution_strategy):
      steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
      if steps_per_epoch is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps_per_epoch argument.')

      if not context.executing_eagerly():
        # Run TPU training in a custom loop in graph mode.
        return experimental_tpu_fit_loop(
            model,
            dataset,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_dataset=val_dataset,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq)

    return training_arrays_v1.fit_loop(
        model,
        dataset,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        val_inputs=val_dataset,
        shuffle=shuffle,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        steps_name='steps_per_epoch')

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Evaluate loop for Distribution Strategies."""
    dist_utils.validate_inputs(x, y)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x, y,
        sample_weight=sample_weight,
        batch_size=batch_size,
        allow_partial_batch=True)

    if K.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')

      if not context.executing_eagerly():
        # Run TPU evaluation in a custom loop in graph mode.
        return experimental_tpu_test_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)

    return training_arrays_v1.test_loop(
        model,
        inputs=dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Predict loop for Distribution Strategies."""
    dist_utils.validate_inputs(x=x, y=None)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT)
    batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
    dataset = model._distribution_standardize_user_data(
        x,
        batch_size=batch_size,
        allow_partial_batch=True)
    if K.is_tpu_strategy(model._distribution_strategy):
      steps = training_utils_v1.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps')
      if steps is None:
        raise ValueError('Number of steps could not be inferred from the data, '
                         'please pass the steps argument.')
      if not context.executing_eagerly():
        return experimental_tpu_predict_loop(
            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
    return training_arrays_v1.predict_loop(
        model,
        dataset,
        batch_size=batch_size,
        verbose=verbose,
        steps=steps,
        callbacks=callbacks)


def _train_with_multi_worker(method):
  """Decorator that handles multi-worker training with distribution strategy."""

  def wrapper(model, **kwargs):
    def _worker_fn(_):
      callbacks = kwargs.pop('callbacks', None)
      filtered_callbacks = dist_utils.filter_distributed_callbacks(
          callbacks, model)
      kwargs['callbacks'] = filtered_callbacks
      return method(model, **kwargs)

    return dc.run_distribute_coordinator(
        _worker_fn,
        model._distribution_strategy,
        mode='independent_worker')

  return wrapper


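# Each public method below wraps the corresponding single-worker loop in
# `_train_with_multi_worker`, which re-runs it under
# `dc.run_distribute_coordinator` with the callback list adjusted via
# `dist_utils.filter_distributed_callbacks`.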
class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
  """Training loop for distribution strategy with multiple workers."""

  def __init__(self, single_worker_loop):
    self._single_worker_loop = single_worker_loop

  def fit(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.fit)(
        *args, **kwargs)

  def evaluate(self, *args, **kwargs):
    return _train_with_multi_worker(self._single_worker_loop.evaluate)(
        *args, **kwargs)

  def predict(self, *args, **kwargs):
    # Currently predict still uses the single-worker implementation.
    return self._single_worker_loop.predict(*args, **kwargs)