• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to distributed training."""
16# pylint:disable=protected-access
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.ops import iterator_ops
25from tensorflow.python.distribute import distribute_coordinator_context as dc_context
26from tensorflow.python.distribute import reduce_util
27from tensorflow.python.eager import context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.keras import backend as K
32from tensorflow.python.keras import callbacks
33from tensorflow.python.keras import metrics as metrics_module
34from tensorflow.python.keras import optimizers
35from tensorflow.python.keras.engine import training_utils
36from tensorflow.python.keras.optimizer_v2 import optimizer_v2
37from tensorflow.python.keras.utils.mode_keys import ModeKeys
38from tensorflow.python.ops import control_flow_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import variables
41from tensorflow.python.platform import tf_logging as logging
42from tensorflow.python.util import nest
43from tensorflow.python.util import tf_contextlib
44
45
def set_weights(distribution_strategy, dist_model, weights):
  """Copies a flat list of weights into the replicated model's variables.

  `weights` is consumed layer by layer: each layer of `dist_model` takes as
  many leading entries as it has weights. The replicated variables are
  Mirrored, so they are written with `assign` inside the strategy. When
  executing eagerly outside functions each assignment happens immediately;
  otherwise the assign ops are unwrapped through `distribution_strategy` and
  run together in one session call.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training
        and validation.
    dist_model: The replicated models on the different devices.
    weights: The weights of the original model.
  """
  eager = ops.executing_eagerly_outside_functions()
  pending_assigns = []
  offset = 0
  for layer in dist_model.layers:
    layer_vars = layer.weights
    count = len(layer_vars)
    for var, value in zip(layer_vars, weights[offset:offset + count]):
      if eager:
        var.assign(value)
      else:
        pending_assigns.append(distribution_strategy.unwrap(var.assign(value)))
    offset += count

  if not eager:
    K.get_session(pending_assigns).run(pending_assigns)
72
73
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates=None, grouped_session_args=None,
                  with_loss_tensor=False):
  """Flattens grouped PerDevice values into plain per-device lists.

  Every grouped argument is passed through `flatten_perdevice_values`. When
  `with_loss_tensor` is True, the first entry of `grouped_outputs` is treated
  as the per-device loss. It is SUM-reduced into a single tensor that is
  placed at the front of the flattened outputs.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_inputs: PerDevice inputs returned from the train or test function
        that we ran on each device.
    grouped_outputs: PerDevice outputs returned from the train or test function
        that we ran on each device.
    grouped_updates: PerDevice updates returned from the train or test function
        that we ran on each device.
    grouped_session_args: PerDevice session args returned from the train or
        test function that we ran on each device. Only the 'feed_dict' and
        'fetches' entries are used.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    Tuple `(all_inputs, all_outputs, all_updates, all_session_args)`.
    `all_updates` is None when `grouped_updates` is empty. `all_session_args`
    is a dict holding only the non-empty session args.
  """
  def _flatten(values):
    return flatten_perdevice_values(distribution_strategy, values)

  # These flattened values are used to construct the main train function.
  all_inputs = _flatten(grouped_inputs)

  if not with_loss_tensor:
    all_outputs = _flatten(grouped_outputs)
  else:
    # Reduce the per-device losses to one tensor before adding it to the
    # list of fetches.
    loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM,
                                        grouped_outputs[0])
    all_outputs = [loss] + _flatten(grouped_outputs[1:])

  all_updates = _flatten(grouped_updates) if grouped_updates else None

  all_session_args = {}
  if grouped_session_args:
    for key in ('feed_dict', 'fetches'):
      grouped = grouped_session_args.get(key)
      if grouped:
        all_session_args[key] = _flatten(grouped)

  # TODO(priyag): Return only non empty/None values
  return all_inputs, all_outputs, all_updates, all_session_args
137
138
def flatten_perdevice_values(distribution_strategy, perdevice_values):
  """Unwraps and flattens a nest of PerDevice parameters.

  The nest is flattened first. Then each leaf is unwrapped with
  `distribution_strategy.unwrap`, and the per-device values of all leaves are
  concatenated in order.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    perdevice_values: List of PerDevice object or a single PerDevice object.

  Returns:
    List of values of all the PerDevice objects.
  """
  flat = []
  for value in nest.flatten(perdevice_values):
    flat.extend(distribution_strategy.unwrap(value))
  return flat
160
161
def validate_callbacks(input_callbacks, optimizer):
  """Validate whether given callbacks are supported by DistributionStrategy.

  A warning is logged for every callback that is not an instance of one of the
  predefined Keras callbacks listed below, since such callbacks may access
  model attributes that are not set under DistributionStrategy. TensorBoard
  options that are unsupported with DistributionStrategy (`histogram_freq`,
  `write_grads`) are switched off on the callback object in place, with a
  warning.

  Args:
    input_callbacks: List of callbacks passed by the user to fit.
    optimizer: Optimizer instance used to train the model.

  Raises:
    ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the
        callbacks passed and `optimizer` is not a Keras `OptimizerV2`.
  """
  if not input_callbacks:
    return

  # Callbacks are instances, so they must be matched with `isinstance`; a
  # membership test against the classes themselves never matches.
  predefined_callback_types = (
      callbacks.TensorBoard, callbacks.ReduceLROnPlateau,
      callbacks.LearningRateScheduler, callbacks.CSVLogger,
      callbacks.EarlyStopping, callbacks.ModelCheckpoint,
      callbacks.TerminateOnNaN, callbacks.ProgbarLogger,
      callbacks.History, callbacks.RemoteMonitor)

  for callback in input_callbacks:
    if not isinstance(callback, predefined_callback_types):
      logging.warning('Your input callback is not one of the predefined '
                      'Callbacks that supports DistributionStrategy. You '
                      'might encounter an error if you access one of the '
                      'model\'s attributes as part of the callback since '
                      'these attributes are not set. You can access each of '
                      'the individual distributed models using the '
                      '`_grouped_model` attribute of your original model.')

    # These callbacks require a Keras OptimizerV2 under DistributionStrategy.
    if isinstance(callback, (callbacks.LearningRateScheduler,
                             callbacks.ReduceLROnPlateau)):
      if not isinstance(optimizer, optimizer_v2.OptimizerV2):
        raise ValueError('You must specify a Keras Optimizer V2 when using '
                         '%s callback with DistributionStrategy.' % callback)

    # If users want to use the TensorBoard callback they cannot use certain
    # features of the callback that involve accessing model attributes and
    # running ops.
    if isinstance(callback, callbacks.TensorBoard):
      if getattr(callback, 'histogram_freq', False):
        logging.warning(
            '`histogram_freq` in the TensorBoard callback is not '
            'supported when using DistributionStrategy. Setting '
            '`histogram_freq` to `0`.')
        callback.histogram_freq = 0
      if getattr(callback, 'write_grads', False):
        logging.warning(
            '`write_grads` in the TensorBoard callback is not supported '
            'when using DistributionStrategy. Setting `write_grads` '
            'to `False`.')
        # Disable the option the warning refers to.
        callback.write_grads = False
214
215
def validate_distributed_dataset_inputs(distribution_strategy, x, y,
                                        sample_weights=None):
  """Validate all the components of a DistributedValue Dataset input.

  `x` is always validated. `y` and `sample_weights` are validated only when
  they are not None. Each is checked with `validate_per_device_inputs`, which
  requires every leaf to be a tensor whose per-device values agree in shape
  and dtype.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
        `fit`/`evaluate`.
    x: Input Dataset DistributedValue object. For example, when we use
        `MirroredStrategy` this is a PerDevice object with a tensor for each
        device set in the dict. x can also be a tuple or dict. The keys of the
        dict should match the names of the input layers of the model.
    y: Target Dataset DistributedValue object. For example, when we use
        `MirroredStrategy` this is a PerDevice object with a tensor for each
        device set in the dict. y can also be a tuple or dict. The keys of the
        dict should match the names of the output layers of the model.
    sample_weights: Sample weights Dataset DistributedValue object. For example,
        when we use `MirroredStrategy` this is a PerDevice object with a tensor
        for each device set in the dict.

  Returns:
    Tuple `(x_values_list, y_values_list, sample_weights_list)` of unwrapped
    values; the last two are None when the corresponding argument is None.

  Raises:
    ValueError: If x, y or sample_weights contain elements that are not
        tensors, or elements whose per-device values have a shape or dtype
        mismatch.
  """
  def _validate_optional(values):
    if values is None:
      return None
    return validate_per_device_inputs(distribution_strategy, values)

  x_values_list = validate_per_device_inputs(distribution_strategy, x)
  y_values_list = _validate_optional(y)
  sample_weights_list = _validate_optional(sample_weights)

  # Return the unwrapped values to avoid calling `unwrap` a second time.
  return x_values_list, y_values_list, sample_weights_list
264
265
def validate_per_device_inputs(distribution_strategy, x):
  """Validates PerDevice dataset input list.

  Every leaf of `x` must be a tensor. Each leaf is unwrapped into its
  per-device values, and those values must all share one shape and one dtype.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`, `evaluate` and `predict`.
    x: A list of PerDevice objects that represent the input or
      target values.

  Returns:
    List containing the first element of each of the PerDevice objects in
    the input list.

  Raises:
    ValueError: If any leaf of `x` is not a tensor, or if the per-device
      values of a leaf differ in shape or dtype.
  """
  first_values = []
  for per_device_value in nest.flatten(x):
    if not tensor_util.is_tensor(per_device_value):
      raise ValueError('Dataset input to the model should be tensors instead '
                       'they are of type {}'.format(type(per_device_value)))

    device_values = distribution_strategy.unwrap(per_device_value)
    validate_all_tensor_shapes(per_device_value, device_values)
    validate_all_tensor_types(per_device_value, device_values)
    first_values.append(device_values[0])
  return first_values
301
302
def validate_all_tensor_types(x, x_values):
  """Checks that every per-device value in `x_values` has the same dtype.

  Args:
    x: The distributed value `x_values` was unwrapped from; only used in the
      error message.
    x_values: Sequence of per-device tensors, each exposing `.dtype`. An
      empty sequence is accepted (there is nothing to mismatch).

  Raises:
    ValueError: If any element's dtype differs from the first element's.
  """
  if not x_values:
    return
  x_dtype = x_values[0].dtype
  for value in x_values[1:]:
    if x_dtype != value.dtype:
      raise ValueError('Input tensor dtypes do not match for distributed tensor'
                       ' inputs {}'.format(x))
309
310
def validate_all_tensor_shapes(x, x_values):
  """Checks that every per-device value in `x_values` has the same shape.

  Shapes are compared as lists via `get_shape().as_list()`.

  Args:
    x: The distributed value `x_values` was unwrapped from; only used in the
      error message.
    x_values: Sequence of per-device tensors, each exposing `get_shape()`. An
      empty sequence is accepted (there is nothing to mismatch).

  Raises:
    ValueError: If any element's shape differs from the first element's.
  """
  if not x_values:
    return
  x_shape = x_values[0].get_shape().as_list()
  for value in x_values[1:]:
    if x_shape != value.get_shape().as_list():
      raise ValueError('Input tensor shapes do not match for distributed tensor'
                       ' inputs {}'.format(x))
318
319
def _wait_for_variable_initialization(session):
  """Blocks until every not-yet-confirmed Keras variable is initialized.

  Variables already flagged with `_keras_initialized` are skipped. The rest
  are polled with `session.run` until all of them report initialized. They
  are then flagged, so later calls skip them.

  Args:
    session: Session used to evaluate the initialization checks.
  """
  all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
  candidate_vars = [v for v in all_variables
                    if not getattr(v, '_keras_initialized', False)]
  if not candidate_vars:
    return

  # Build the check ops once. Creating them inside the polling loop would add
  # new ops to the graph on every iteration for as long as we wait.
  is_initialized_ops = [variables.is_variable_initialized(v)
                        for v in candidate_vars]
  while not all(session.run(is_initialized_ops)):
    pass

  # Only flag the variables once they have all been confirmed initialized.
  for v in candidate_vars:
    v._keras_initialized = True  # pylint: disable=protected-access
341
342
def init_restore_or_wait_for_variables():
  """Initializes variables, or waits for another worker to initialize them.

  Variables are initialized in the Keras session when there is no distribute
  coordinator worker context, or when that context reports
  `experimental_should_init`. Otherwise this blocks until the variables have
  been initialized elsewhere.
  """
  session = K._get_session()  # pylint: disable=protected-access
  worker_context = dc_context.get_current_worker_context()
  if worker_context and not worker_context.experimental_should_init:
    _wait_for_variable_initialization(session)
  else:
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
352
353
def validate_inputs(x, y, distribution_strategy, allow_partial_batch=False):
  """Validate inputs when using DistributionStrategy.

  Args:
    x: Model Inputs.
    y: Model Targets.
    distribution_strategy: The DistributionStrategy with which the model is
      compiled.
    allow_partial_batch: Boolean. If false, datasets must have fully
      defined shapes.

  Raises:
    ValueError: If `x` or `y` is an `Iterator`, or if a TPU strategy is used
      with a `Dataset` whose shapes are not fully defined while
      `allow_partial_batch` is False.
  """
  if (isinstance(x, iterator_ops.Iterator) or
      isinstance(y, iterator_ops.Iterator)):
    raise ValueError('`DistributionStrategy` does not support inputs of type '
                     'Iterator. You must pass a `tf.data.Dataset` object or a '
                     'numpy array as input.')

  if is_tpu_strategy(distribution_strategy) and not allow_partial_batch:
    for value in (x, y):
      if (isinstance(value, dataset_ops.DatasetV2) and
          not is_dataset_shape_fully_defined(value)):
        # Note the trailing space after "drop_remainder=True)." so the two
        # sentences do not run together.
        raise ValueError(
            'Using TPUs currently requires fully defined shapes. Either use '
            'set_shape() on the input tensors or use '
            'dataset.batch(..., drop_remainder=True). '
            'Found unknown shape in input {}.'.format(value))
384
385
386# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
387# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
388# no longer needed.
def global_batch_size_supported(distribution_strategy):
  """Returns the strategy's `extended._global_batch_size` attribute.

  The attribute's value is returned as-is. It is used as a flag for whether
  user-provided batch sizes are global (True) rather than per-replica.
  """
  extended = distribution_strategy.extended
  return extended._global_batch_size  # pylint: disable=protected-access
391
392
393# TODO(sourabhbajaj): Remove this once we use the same API for all strategies.
def is_tpu_strategy(strategy):
  """Returns True if `strategy`'s class is named `TPUStrategy`.

  Only the class name is compared (no `isinstance`), so differently named
  subclasses do not match. A None strategy returns False.
  """
  if strategy is None:
    return False
  return strategy.__class__.__name__ == 'TPUStrategy'
397
398
def is_dataset_shape_fully_defined(dataset):
  """Returns whether every component shape of `dataset` is fully defined.

  The legacy output shapes of `dataset` are flattened and each must report
  `is_fully_defined()`. This returns False, for example, when some dimension
  is unknown, as is typically the batch dimension of a dataset batched
  without `drop_remainder=True`.
  """
  shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
  return all(shape.is_fully_defined() for shape in shapes)
404
405
def get_input_params(distribution_strategy, first_x_value, steps, batch_size,
                     mode=None):
  """Calculate the number of batches and steps/steps_per_epoch.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.
    first_x_value: This is the first input numpy array that is passed in as the
      model input. Only its leading dimension (number of samples) is used.
    steps: The specified number of steps, or None to derive it.
    batch_size: The specified batch_size, or None to derive it. It is treated
      as a per-replica batch size when the strategy does not support global
      batch sizes.
    mode: ModeKey representing whether input will be used for training,
      evaluation, or prediction. Partial batches are allowed for
      `ModeKeys.TRAIN`, and for `ModeKeys.PREDICT` on TPUStrategy, to keep
      compatibility till we support partial batches. If None, then partial
      batches are not allowed.

  Returns:
    steps: The steps or steps_per_epoch argument depending on if a user is
        calling `fit`, `evaluate` or `predict`. When partial batches are
        allowed, the samples need not be consumed exactly.
    batch_size: The batch size to be used in model iterations; per-replica or
        global depending on what the strategy expects.

  Raises:
    ValueError: If the samples cannot be split evenly into the requested
      batches/steps (when partial batches are not allowed), if there are fewer
      samples than the given steps and batch size require, if the batch size
      cannot be sharded evenly across replicas, or if the batch size or
      number of steps that would be divided by is 0.
  """
  num_samples = first_x_value.shape[0]
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
  use_per_replica_batch = not global_batch_size_supported(
      distribution_strategy)

  # Partial batches are allowed for training as we repeat the
  # dataset when converting numpy arrays into a dataset.
  # For other modes uneven batch sizes are not allowed except
  # for `predict()` on TPUStrategy.
  allow_partial_batch = (mode == ModeKeys.TRAIN or
                         (mode == ModeKeys.PREDICT
                          and is_tpu_strategy(distribution_strategy)))

  if batch_size is not None:
    # If the user provided the batch size we need to handle the case
    # between different strategies that use the global/per-replica batch size
    global_batch_size = batch_size
    if use_per_replica_batch:
      global_batch_size *= distribution_strategy.num_replicas_in_sync

  if steps is None:
    if batch_size is None:
      # If neither the batch size or number of steps are set. We choose the
      # global batch size as the minimum of number of samples and 32. 32 is
      # chosen to provide backward compatibility.
      global_batch_size = min(num_samples, 32)
    # Guard the divisions below, which would otherwise raise
    # ZeroDivisionError (e.g. for an empty input or batch_size=0).
    if not global_batch_size:
      raise ValueError('The global batch size evaluates to 0 (number of '
                       'samples: %s, batch size: %s), so the number of steps '
                       'cannot be computed.' % (num_samples, batch_size))
    if allow_partial_batch:
      steps = np.ceil(num_samples / global_batch_size).astype(int)
    else:
      if num_samples % global_batch_size:
        raise ValueError('The number of samples %s is not divisible by '
                         'batch size %s.' % (num_samples, global_batch_size))
      steps = num_samples // global_batch_size
  else:
    if batch_size is None:
      # Guard the division by `steps` below.
      if not steps:
        raise ValueError('The number of steps must be non-zero when the batch '
                         'size is not specified, got %s.' % steps)
      # We calculate the batch size based on the number of steps specified
      if num_samples % steps:
        raise ValueError('The number of samples %s is not divisible by '
                         'steps %s. Please change the number of steps to a '
                         'value that can consume all the samples' % (
                             num_samples, steps))
      global_batch_size = num_samples // steps
    else:
      min_num_samples = global_batch_size * steps
      if allow_partial_batch:
        min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0

      if num_samples < min_num_samples:
        raise ValueError('Number of samples %s is less than samples required '
                         'for specified batch_size %s and steps %s' % (
                             num_samples, global_batch_size, steps))

  # We need to return the per replica or global batch size based on the strategy
  if use_per_replica_batch:
    if global_batch_size % distribution_strategy.num_replicas_in_sync:
      raise ValueError(
          'The batch size (%s) could not be sharded evenly across the sync '
          'replicas (%s) in the distribution strategy.' % (
              global_batch_size, distribution_strategy.num_replicas_in_sync))
    batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync
  else:
    batch_size = global_batch_size

  return steps, batch_size
501
502
def get_batch_dimension(iterator):
  """Returns the leading (batch) dimension of the first output component.

  Args:
    iterator: A dataset or iterator accepted by
      `dataset_ops.get_legacy_output_shapes`.

  Returns:
    The first entry of the first flattened output shape's `dims`, or None if
    that shape's `dims` is empty or None.
  """
  # The batch size is taken from the first component, as it should be the
  # same for all of them.
  first_shape = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator))[0]
  dims = first_shape.dims
  if not dims:
    return None
  return dims[0]
509
510
def list_to_tuple(maybe_list):
  """Converts a list to a tuple; any other value is returned unchanged.

  Datasets treat lists specially, so lists are switched to tuples.
  """
  return tuple(maybe_list) if isinstance(maybe_list, list) else maybe_list
516
517
def get_iterator(dataset, distribution_strategy):
  """Creates an initialized distributed iterator over `dataset`.

  The iterator is built with `make_dataset_iterator` inside the strategy's
  scope, then initialized before being returned.
  """
  with distribution_strategy.scope():
    dist_iterator = distribution_strategy.make_dataset_iterator(dataset)
  initialize_iterator(dist_iterator, distribution_strategy)
  return dist_iterator
523
524
def initialize_iterator(iterator, distribution_strategy):
  """Initializes `iterator` inside the scope of `distribution_strategy`.

  The iterator's initializers are grouped into one op. Outside eager mode
  that op is run in the Keras session; in eager mode nothing more is done.
  """
  with distribution_strategy.scope():
    initializer = control_flow_ops.group(iterator.initialize())
    if context.executing_eagerly():
      return
    K.get_session((initializer,)).run(initializer)
530
531
def _get_input_from_iterator(iterator, model):
  """Reads the next element and splits it into (x, y, sample_weights).

  The split is decided by the number of flattened components. If it equals
  the number of model inputs, the element is x only. If it equals inputs plus
  outputs, the element is (x, y). Otherwise it is unpacked as
  (x, y, sample_weights). Missing parts are None. The parts are then checked
  for matching per-device types and shapes.
  """
  next_element = iterator.get_next()
  num_components = len(nest.flatten(next_element))
  num_inputs = len(model.inputs)

  y = None
  sample_weights = None
  if num_components == num_inputs:
    x = next_element
  elif num_components == num_inputs + len(model.outputs):
    x, y = next_element
  else:
    x, y, sample_weights = next_element

  # Validate that all the elements in x and y are of the same type and shape.
  validate_distributed_dataset_inputs(
      model._distribution_strategy, x, y, sample_weights)
  return x, y, sample_weights
551
552
def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Builds the flat list of values fed to the model execution function.

  `inputs` is used as the distributed iterator: the actual inputs, targets
  and sample weights are read from it, and the `targets` and `sample_weights`
  arguments are not used. Sample weights read from the iterator are discarded
  as well. Outside PREDICT mode, one `None` per model output per replica is
  appended instead.

  Arguments:
    model: Model to prepare feed values for; its `_distribution_strategy` is
      used to unwrap per-replica values.
    inputs: Iterator yielding the distributed inputs (and optionally targets
      and sample weights).
    targets: Unused; replaced by the values read from `inputs`.
    sample_weights: Unused; replaced by the values read from `inputs`.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    List of feed values:
      * the flattened per-replica inputs;
      * outside PREDICT, the flattened targets and the `None` sample weights;
      * in TRAIN mode, `True` when the symbolic learning phase is not a
        Python int.
  """
  strategy = model._distribution_strategy
  inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
  flat_inputs = flatten_perdevice_values(strategy, inputs)
  flat_targets = flatten_perdevice_values(strategy, targets)
  # Expand 1-dimensional inputs.
  # TODO(b/124535720): Remove once this standarize data logic is shared with
  # main flow.
  flat_inputs, flat_targets = nest.map_structure(
      training_utils.standardize_single_array, (flat_inputs, flat_targets))

  if mode == ModeKeys.PREDICT:
    targets_and_weights = []
  else:
    num_weights = len(model.outputs) * strategy.num_replicas_in_sync
    targets_and_weights = flat_targets + [None] * num_weights
  ins = flat_inputs + targets_and_weights

  if mode == ModeKeys.TRAIN and not isinstance(K.symbolic_learning_phase(),
                                               int):
    ins += [True]
  return ins
587
588
589def _custom_compile_for_predict(model):
590  """Custom compile for TPU predict mode."""
591  if not model.built:
592    # Model is not compilable because it does not know its number of inputs
593    # and outputs, nor their shapes and names. We will compile after the first
594    # time the model gets called on training data.
595    return
596  model._is_compiled = True
597  model.total_loss = None
598  model.train_function = None
599  model.test_function = None
600  model.predict_function = None
601
602
def _build_network_on_replica(model, mode, inputs=None, targets=None):
  """Build an updated model on replicas that shares layers with `model`.

  A new Keras model is built while sharing the variables of the old graph.
  This is required because the original model's input and output
  placeholders are not accessible until `iterator.get_next()` is called
  inside the step_fn for `fit`/`evaluate`/`predict`.

  Sharing weights and layers between the old and the new model guarantees
  that Strategy variables are used, and that updates through either model are
  reflected correctly in callbacks and loop iterations. The optimizer is
  shared too, so optimizer state is not lost when `fit` is run multiple
  times.

  Args:
    model: Model to be replicated across Replicas
    mode: Which of fit/eval/predict is building the distributed network
    inputs: Input variables to be passed to the model
    targets: Target tensor to be passed to model.compile

  Returns:
    A new model with shared layers with the old model.
  """
  # Imported here since importing at module level causes a circular
  # dependency error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  # The internal clone helpers are used to avoid exposing `share_weights` in
  # the public API.
  if isinstance(model, sequential.Sequential):
    clone_fn = models._clone_sequential_model
  else:
    clone_fn = models._clone_functional_model
  updated_model = clone_fn(model, input_tensors=inputs, share_weights=True)

  # Only the inputs were cast to bfloat16, not the targets, so cast bfloat16
  # outputs back to float32 to preserve precision in the loss.
  def _as_float32_if_bfloat16(tensor):
    if tensor.dtype == dtypes.bfloat16:
      return math_ops.cast(tensor, dtypes.float32)
    return tensor

  updated_model.outputs = list(map(_as_float32_if_bfloat16,
                                   updated_model.outputs))

  flat_targets = nest.flatten(targets) if isinstance(targets, tuple) else targets

  is_tpu_predict = mode == ModeKeys.PREDICT and inputs is not None
  if is_tpu_predict:
    _custom_compile_for_predict(updated_model)
  else:
    updated_model.compile(
        model.optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=flat_targets)
  return updated_model
668
669
def _build_distributed_network(model, strategy, mode, inputs=None,
                               targets=None):
  """Builds a layer-sharing network for `model` on each replica.

  `_build_network_on_replica` runs once per replica under the Keras graph and
  the strategy's scope. The resulting distributed model is registered on
  `model` for `mode` via `set_distributed_model`.
  """
  replica_args = (model, mode, inputs, targets)
  with K.get_graph().as_default(), strategy.scope():
    distributed_model = strategy.extended.call_for_each_replica(
        _build_network_on_replica, args=replica_args)
    set_distributed_model(model, mode, distributed_model)
678
679
def _clone_and_build_model(model, mode, inputs=None, targets=None):
  """Clones `model` with `models.clone_model` and compiles the clone.

  A `TFOptimizer` is reused as-is. Any other optimizer is re-created from
  its config, so the clone gets its own optimizer instance. bfloat16 outputs
  are cast back to float32. In the TPU predict case (PREDICT mode with
  `inputs` given), the clone is only marked compiled; otherwise it is
  compiled with the original model's loss, metrics and weights settings.
  """
  # Imported here since importing at module level causes a circular
  # dependency error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  cloned_model = models.clone_model(model, input_tensors=inputs)

  source_optimizer = model.optimizer
  if isinstance(source_optimizer, optimizers.TFOptimizer):
    optimizer = source_optimizer
  else:
    optimizer_config = source_optimizer.get_config()
    optimizer = source_optimizer.__class__.from_config(optimizer_config)

  # Only the inputs were cast to bfloat16, not the targets, so cast bfloat16
  # outputs back to float32 to preserve precision in the loss.
  def _as_float32_if_bfloat16(tensor):
    if tensor.dtype == dtypes.bfloat16:
      return math_ops.cast(tensor, dtypes.float32)
    return tensor

  cloned_model.outputs = list(map(_as_float32_if_bfloat16,
                                  cloned_model.outputs))

  flat_targets = nest.flatten(targets) if isinstance(targets, tuple) else targets

  is_tpu_predict = mode == ModeKeys.PREDICT and inputs is not None
  if is_tpu_predict:
    _custom_compile_for_predict(cloned_model)
  else:
    cloned_model.compile(
        optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=flat_targets)
  return cloned_model
720
721
def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
  """Builds one cloned, compiled copy of `model` per replica.

  The clones are produced by `_clone_and_build_model` inside the Keras graph
  and `strategy.scope()`, then stored as the distributed model for `mode`.
  In `ModeKeys.TRAIN` they are also passed to `model._make_callback_model`.

  Args:
    model: The original Keras model.
    strategy: The distribution strategy used to replicate the model.
    mode: A `ModeKeys` value.
    inputs: Optional input tensors forwarded to each clone.
    targets: Optional target tensors forwarded to each clone.
  """
  keras_graph = K.get_graph()
  with keras_graph.as_default(), strategy.scope():
    replicated_model = strategy.extended.call_for_each_replica(
        _clone_and_build_model, args=(model, mode, inputs, targets))
    set_distributed_model(model, mode, replicated_model)

  if mode == ModeKeys.TRAIN:
    model._make_callback_model(replicated_model)
730
731
def _make_execution_function(model, mode):
  """Returns the function that runs one distributed step of `model`.

  The function is cached on the distributed model for `mode`. On a cache miss
  the distributed model is built first: cloned and compiled per replica when
  `model._compile_distribution` is set, otherwise via
  `_build_distributed_network`. An eager or graph execution function is then
  created, depending on `context.executing_eagerly()`.

  Args:
    model: The original Keras model carrying `_distribution_strategy`.
    mode: A `ModeKeys` value.

  Returns:
    The distributed execution function for `mode`.
  """
  cached = get_distributed_model(model, mode)
  if cached:
    # Already built for this mode: reuse the function stored on it.
    return cached._distributed_function

  strategy = model._distribution_strategy
  build_replicas = (clone_model_on_replicas if model._compile_distribution
                    else _build_distributed_network)
  build_replicas(model, strategy, mode)

  # The builder above has just registered a distributed model for `mode`.
  built = get_distributed_model(model, mode)
  assert built

  make_step_fn = (_make_eager_execution_function
                  if context.executing_eagerly()
                  else _make_graph_execution_function)
  step_fn = make_step_fn(model, mode)

  # Creating distributed models and their execution functions is expensive,
  # so keep the result on the distributed model for later calls.
  built._distributed_function = step_fn
  return step_fn
763
764
def _make_graph_execution_function(model, mode):
  """Builds a graph-mode Keras function that runs one distributed step.

  Each replica's model builds its own execution function for `mode`. The
  per-replica inputs, outputs, update ops and session arguments are then
  unwrapped and merged into a single `K.function`.

  Args:
    model: The original Keras model. Its `_distribution_strategy` and the
      distributed model cached for `mode` are used.
    mode: A `ModeKeys` value selecting which execution function to build.

  Returns:
    A Keras function named `distributed_<mode>_function`.
  """

  def _replica_fn(replica_model):
    # Build this replica's execution function and expose the pieces needed
    # to assemble the combined function.
    fn = replica_model._make_execution_function(mode)
    return fn.inputs, fn.outputs, fn.updates_op, fn.session_kwargs

  strategy = model._distribution_strategy
  with strategy.scope():
    # Calling `_replica_fn` on every replica creates the step ops on each of
    # the devices.
    per_replica = strategy.extended.call_for_each_replica(
        _replica_fn, args=(get_distributed_model(model, mode),))
    grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args = (
        per_replica)

    # Initialize the replicated model's variables, or wait for them to be
    # initialized, as dictated by the distribute coordinator's worker context.
    # This matters for multi-worker training, where not every worker needs to
    # run the initialization itself.
    init_restore_or_wait_for_variables()

    # Unwrap the per-replica values into plain lists spanning all devices so
    # they can feed one combined function.
    all_inputs, all_outputs, all_updates, all_session_args = unwrap_values(
        strategy,
        grouped_inputs,
        grouped_outputs,
        grouped_updates,
        grouped_session_args,
        with_loss_tensor=(mode != ModeKeys.PREDICT))

    return K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_{}_function'.format(mode),
        **all_session_args)
804
805
def _make_eager_execution_function(model, mode):
  """Builds a Keras function that runs one distributed step in eager mode.

  The per-replica execution functions are gathered inside the global Keras
  graph. Their inputs and outputs are then unwrapped and combined into one
  `K.function`.

  Args:
    model: The original Keras model. Its `_distribution_strategy` and the
      distributed model cached for `mode` are used.
    mode: A `ModeKeys` value selecting which execution function to build.

  Returns:
    A Keras function named `eager_distributed_<mode>_function`.
  """

  def _replica_fn(replica_model):
    fn = replica_model._make_execution_function(mode)
    return fn.inputs, fn.outputs

  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
  # the global one.
  strategy = model._distribution_strategy
  graph = K.get_graph()

  with graph.as_default(), strategy.scope():
    # `K._scratch_graph(graph)` signals to Keras that it should not lift the
    # per-replica functions into a separate graph.
    with K._scratch_graph(graph):
      # Calling `_replica_fn` on every replica creates the step ops on each
      # of the devices.
      per_replica = strategy.extended.call_for_each_replica(
          _replica_fn, args=(get_distributed_model(model, mode),))
      grouped_inputs, grouped_outputs = per_replica

      # Unwrap the per-replica tensors into plain lists covering every device.
      all_inputs, all_outputs, _, _ = unwrap_values(
          strategy,
          grouped_inputs,
          grouped_outputs,
          with_loss_tensor=(mode != ModeKeys.PREDICT))

    # The joint Keras function is created outside the scratch-graph scope, so
    # it lives in a separate FuncGraph.
    return K.function(
        all_inputs,
        all_outputs,
        name='eager_distributed_{}_function'.format(mode))
844
845
def _copy_weights_to_distributed_model(original_model, mode):
  """Copies weights from the original model into the distributed model.

  Does nothing when `original_model` has no distribution strategy. Otherwise
  the original model's weights are handed to `set_weights` together with the
  first replica model of the distributed model cached for `mode`.

  Args:
    original_model: The user-facing Keras model.
    mode: A `ModeKeys` value selecting which distributed model to update.
  """
  strategy = original_model._distribution_strategy
  if not strategy:
    # No replicas to copy into; skip the distributed-model cache lookup
    # entirely, since it would be wasted work and may not exist.
    return

  distributed_model = get_distributed_model(original_model, mode)
  orig_model_weights = original_model.get_weights()
  # Only the first replica is passed. The intent is for the weights to reach
  # every replicated model, so `set_weights` presumably propagates them
  # through `strategy` -- confirm against its definition.
  first_model = strategy.unwrap(distributed_model)[0]
  set_weights(strategy, first_model, orig_model_weights)
856
857
def _copy_weights_to_original_model(model, mode):
  """Syncs weights from the first replica model back into `model`.

  This is a no-op unless `model` has a distribution strategy and `mode` is
  `ModeKeys.TRAIN`.

  Args:
    model: The user-facing Keras model to update.
    mode: A `ModeKeys` value.
  """
  if not (model._distribution_strategy and mode == ModeKeys.TRAIN):
    return
  replicas = model._distribution_strategy.unwrap(
      get_distributed_model(model, mode))
  model.set_weights(replicas[0].get_weights())
865
866
def _per_device_aggregate_batch(batch_outs, model, mode):
  """Aggregates the per-device batch-level outputs from a distributed step.

  Only predict mode under a distribution strategy needs aggregation. There,
  `batch_outs` is expected to hold `num_replicas_in_sync` consecutive entries
  per model output (all replicas for output 0, then all for output 1, and so
  on). Each such group is flattened with `nest.flatten` and joined with
  `np.concatenate` into one array per output.

  Args:
    batch_outs: List of per-replica outputs returned by one distributed step.
    model: The original Keras model, which provides `_distribution_strategy`
      and `outputs`.
    mode: A `ModeKeys` value.

  Returns:
    One concatenated array per model output in the predict case; otherwise
    `batch_outs` unchanged.
  """
  strategy = model._distribution_strategy
  if strategy is None or mode != ModeKeys.PREDICT:
    return batch_outs

  # Loop-invariant: read the replica count once rather than once per output.
  num_replicas = strategy.num_replicas_in_sync
  return [
      np.concatenate(
          nest.flatten(batch_outs[i * num_replicas:(i + 1) * num_replicas]))
      for i in range(len(model.outputs))
  ]
877
878
def _reset_metrics(model):
  """Calls `reset_metrics()` on the first replica model of each built mode.

  This does nothing when `model` has no distribution strategy. Modes whose
  distributed model has not been built are skipped.

  Args:
    model: The user-facing Keras model.
  """
  if not model._distribution_strategy:
    return
  for mode in (ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT):
    replicated = get_distributed_model(model, mode)
    if not replicated:
      continue
    first_replica = model._distribution_strategy.unwrap(replicated)[0]
    first_replica.reset_metrics()
886
887
def get_distributed_model(model, mode):
  """Returns the distributed model cached on `model` for `mode`, or None."""
  cache_key = _generate_cache_key(mode)
  cache = model._distributed_model_cache
  return cache.get(cache_key)
891
892
def set_distributed_model(model, mode, distributed_model):
  """Stores `distributed_model` in `model`'s cache under the key for `mode`."""
  model._distributed_model_cache[_generate_cache_key(mode)] = distributed_model
896
897
def _generate_cache_key(mode):
  """Returns the distributed-model cache key for `mode` (its `hash`)."""
  return hash(mode)
901
902
@tf_contextlib.contextmanager
def distributed_scope(strategy, learning_phase):
  """Enters `strategy.scope()` and then a Keras learning-phase scope.

  Args:
    strategy: The distribution strategy whose scope is entered.
    learning_phase: The value passed to `K.learning_phase_scope`.

  Yields:
    Nothing; the body of the `with` statement runs inside both scopes.
  """
  with strategy.scope():
    with K.learning_phase_scope(learning_phase):
      yield
907
908
def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances, or `None` (treated as an
      empty list).

  Returns:
    The list of `Callback` instances that should be run on this worker. The
    chief worker gets every callback; other workers get only the callbacks
    whose `_chief_worker_only` attribute is falsy.

  Raises:
    ValueError: If Keras is not in multi-worker mode.
  """

  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  worker_context = dc_context.get_current_worker_context()
  callbacks_list = callbacks_list or []
  if not any(isinstance(c, callbacks.ModelCheckpoint) for c in callbacks_list):
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to.
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')
  # TODO(rchao): Add similar warning for restoring callback (to be designed).

  # `callbacks_list` was normalized above and can no longer be None, so only
  # the worker's role decides what to return.
  if worker_context.is_chief:
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list
      if not callback._chief_worker_only  # pylint: disable=protected-access
  ]
942