# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities related to distributed training."""
16# pylint:disable=protected-access
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23import numpy as np
24
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import iterator_ops
27from tensorflow.python.distribute import distribute_coordinator_context as dc_context
28from tensorflow.python.distribute import distribution_strategy_context as ds_context
29from tensorflow.python.distribute import multi_worker_util
30from tensorflow.python.distribute import reduce_util
31from tensorflow.python.eager import context
32from tensorflow.python.eager import def_function
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import sparse_tensor
36from tensorflow.python.framework import tensor_util
37from tensorflow.python.keras import backend as K
38from tensorflow.python.keras import callbacks
39from tensorflow.python.keras import metrics as metrics_module
40from tensorflow.python.keras import optimizers
41from tensorflow.python.keras.engine import training_utils
42from tensorflow.python.keras.optimizer_v2 import optimizer_v2
43from tensorflow.python.keras.utils.mode_keys import ModeKeys
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import sparse_ops
47from tensorflow.python.ops import variables
48from tensorflow.python.ops.ragged import ragged_concat_ops
49from tensorflow.python.ops.ragged import ragged_tensor
50from tensorflow.python.platform import tf_logging as logging
51from tensorflow.python.util import nest
52from tensorflow.python.util import tf_contextlib


def set_weights(distribution_strategy, dist_model, weights):
  """Sets the weights of the replicated models.

  The weights of the replicated models are set to the weights of the original
  model. The weights of the replicated model are Mirrored variables and hence
  we need to use the `update` call within a DistributionStrategy scope.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training
        and validation.
    dist_model: The replicated models on the different devices.
    weights: The weights of the original model.
  """
  assign_ops = []
  for layer in dist_model.layers:
    num_param = len(layer.weights)
    layer_weights = weights[:num_param]
    for sw, w in zip(layer.weights, layer_weights):
      if ops.executing_eagerly_outside_functions():
        sw.assign(w)
      else:
        assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
    weights = weights[num_param:]

  if not ops.executing_eagerly_outside_functions():
    K.get_session(assign_ops).run(assign_ops)


def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates=None, grouped_session_args=None,
                  with_loss_tensor=False):
  """Unwrap the list of values contained in the PerReplica parameters.

  This function calls `flatten_per_replica_values` to parse each of the input
  parameters into a list of values on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_inputs: PerReplica inputs returned from the train or test function
        that we ran on each device.
    grouped_outputs: PerReplica outputs returned from the train or test function
        that we ran on each device.
    grouped_updates: PerReplica updates returned from the train or test function
        that we ran on each device.
    grouped_session_args: PerReplica session args returned from the train or
        test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    Values of each of the PerReplica parameters.

  """
  # Unwrap per device values returned from each model's train function.
  # This will be used to construct the main train function.
  all_inputs = flatten_per_replica_values(distribution_strategy,
                                          grouped_inputs)
  all_outputs = unwrap_outputs(distribution_strategy, grouped_outputs,
                               with_loss_tensor)

  if grouped_updates:
    all_updates = flatten_per_replica_values(distribution_strategy,
                                             grouped_updates)
  else:
    all_updates = None

  all_session_args = {}
  if grouped_session_args:
    grouped_feed_dict = grouped_session_args.get('feed_dict')
    if grouped_feed_dict:
      all_session_args['feed_dict'] = flatten_per_replica_values(
          distribution_strategy, grouped_feed_dict)

    grouped_fetches = grouped_session_args.get('fetches')
    if grouped_fetches:
      all_session_args['fetches'] = flatten_per_replica_values(
          distribution_strategy, grouped_fetches)

  # TODO(priyag): Return only non empty/None values
  return all_inputs, all_outputs, all_updates, all_session_args


def unwrap_output_dict(strategy, grouped_outputs, mode):
  """Unwrap the list of outputs contained in the PerReplica parameters."""
  if mode == ModeKeys.PREDICT:
    return flatten_per_replica_values(strategy, grouped_outputs)

  # In the case of fit/eval, the grouped_outputs is a dict, whereas in predict,
  # the output has the same structure as the model output. They need to be
  # treated differently.
  total_loss = strategy.reduce(reduce_util.ReduceOp.SUM,
                               grouped_outputs['total_loss'][0], axis=None)
  output_losses = flatten_per_replica_values(strategy,
                                             grouped_outputs['output_losses'])
  metrics = flatten_per_replica_values(strategy,
                                       grouped_outputs['metrics'])
  batch_size = strategy.reduce(reduce_util.ReduceOp.SUM,
                               grouped_outputs['batch_size'], axis=None)
  if (is_tpu_strategy(strategy) and
      ops.executing_eagerly_outside_functions()):
    # Choose 1 value per replica in the TPU case since all replicas produce the
    # same output.
    # We only do this in eager mode for now, since this function is used in
    # both graph and eager mode and in the graph case we currently don't use
    # experimental_run, so this will need to be revisited when we converge the
    # graph code path as well.
    output_losses = output_losses[::strategy.num_replicas_in_sync]
    metrics = metrics[::strategy.num_replicas_in_sync]
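    # The stride of `num_replicas_in_sync` keeps the first replica's copy of
    # each loss/metric, since the flattened list is ordered output-major with
    # the replicas of each output laid out consecutively.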
  return {'total_loss': [total_loss],
          'output_losses': output_losses,
          'metrics': metrics,
          'batch_size': batch_size}


def unwrap_outputs(distribution_strategy, grouped_outputs,
                   with_loss_tensor=False):
  """Unwrap the list of outputs contained in the PerReplica parameters.

  This function calls `flatten_per_replica_values` to parse each of the input
  parameters into a list of outputs on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_outputs: PerReplica outputs returned from the train or test function
        that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    Values of each of the PerReplica outputs.

  """
  if not with_loss_tensor:
    return flatten_per_replica_values(distribution_strategy,
                                      grouped_outputs)

  if not isinstance(grouped_outputs, list):
    grouped_outputs = [grouped_outputs]
  # Reduce the loss tensor before adding it to the list of fetches.
  loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM,
                                      grouped_outputs[0], axis=None)
  all_outputs = flatten_per_replica_values(distribution_strategy,
                                           grouped_outputs[1:])
  if (is_tpu_strategy(distribution_strategy) and
      ops.executing_eagerly_outside_functions()):
    # Choose 1 value per replica in the TPU case since all replicas produce the
    # same output.
    # We only do this in eager mode for now, since this function is used in
    # both graph and eager mode and in the graph case we currently don't use
    # experimental_run, so this will need to be revisited when we converge the
    # graph code path as well.
    all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
  return [loss] + all_outputs


def flatten_per_replica_values(distribution_strategy, per_replica_values):
  """Unwraps and flattens a nest of PerReplica parameters.

  PerReplica values have one value associated with each device. Each entry in
  the PerReplica dict has a device `key` and the corresponding value on the
  device as the `value`. In this function we take a PerReplica value or a list
  of PerReplica values and return all the values in the PerReplica dict.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
      validation.
    per_replica_values: List of PerReplica objects or a single PerReplica
      object.

  Returns:
    List of values of all the PerReplica objects.

  """
  # pylint: disable=g-complex-comprehension
  # This function takes a PerReplica object or a list of PerReplica objects and
  # returns all the values associated with it.
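  # For example (hypothetical values), with two replicas in sync a single
  # PerReplica value holding tensors a and b flattens to [a, b], and a list of
  # two such values flattens to [a0, b0, a1, b1].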
  return [e for flattened in nest.flatten(per_replica_values)
          for e in distribution_strategy.unwrap(flattened)]


def validate_callbacks(input_callbacks, optimizer):
  """Validate whether given callbacks are supported by DistributionStrategy.

  Args:
    input_callbacks: List of callbacks passed by the user to fit.
    optimizer: Optimizer instance used to train the model.

  Raises:
    ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the
        callbacks passed.
    ValueError: If `write_grads` is one of the parameters passed as part of the
        TensorBoard callback.
  """
  if input_callbacks:
    for callback in input_callbacks:
      if isinstance(callback, (callbacks.LearningRateScheduler,
                               callbacks.ReduceLROnPlateau)):

        if not isinstance(optimizer, optimizer_v2.OptimizerV2):
          raise ValueError('You must specify a Keras Optimizer V2 when using '
                           '%s callback with DistributionStrategy.' % callback)

      # If users want to use the TensorBoard callback they cannot use certain
      # features of the callback that involve accessing model attributes and
      # running ops.
      if isinstance(callback, callbacks.TensorBoard):
        if getattr(callback, 'write_grads', False):
          logging.warning(
              UserWarning(
                  '`write_grads` in the TensorBoard callback is not supported '
                  'when using DistributionStrategy. Setting `write_grads` '
                  'to `False`.'))
          callback.write_grads = False


def validate_distributed_dataset_inputs(distribution_strategy, x, y,
                                        sample_weights=None):
  """Validate all the components of a DistributedValue Dataset input.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
        `fit`/`evaluate`.
    x: Input Dataset DistributedValue object. For example, when we use
        `MirroredStrategy` this is a PerReplica object with a tensor for each
        device set in the dict. x can also be a tuple or dict. The keys of the
        dict should match the names of the input layers of the model.
    y: Target Dataset DistributedValue object. For example, when we use
        `MirroredStrategy` this is a PerReplica object with a tensor for each
        device set in the dict. y can also be a tuple or dict. The keys of the
        dict should match the names of the output layers of the model.
    sample_weights: Sample weights Dataset DistributedValue object. For example,
        when we use `MirroredStrategy` this is a PerReplica object with a tensor
        for each device set in the dict.

  Returns:
    The unwrapped values list of the x and y DistributedValues inputs.

  Raises:
    ValueError: If x and y do not have support for being evaluated as tensors,
        or if x and y contain elements that are not tensors, or if x and y
        contain elements that have a shape or dtype mismatch.
  """
  # If the input and target used to call the model are not dataset tensors,
  # we need to raise an error. When using a DistributionStrategy, the input
  # and targets to a model should be from a `tf.data.Dataset`.

  # If the elements of x and y are not tensors, we cannot standardize and
  # validate the input and targets.
  x_values_list = validate_per_replica_inputs(distribution_strategy, x)

  if y is not None:
    y_values_list = validate_per_replica_inputs(distribution_strategy, y)
  else:
    y_values_list = None

  if sample_weights is not None:
    sample_weights_list = validate_per_replica_inputs(distribution_strategy,
                                                      sample_weights)
  else:
    sample_weights_list = None

  # Return the unwrapped values to avoid calling `unwrap` a second time.
  return x_values_list, y_values_list, sample_weights_list


def validate_per_replica_inputs(distribution_strategy, x):
  """Validates PerReplica dataset input list.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`, `evaluate` and `predict`.
    x: A list of PerReplica objects that represent the input or
      target values.

  Returns:
    List containing the first element of each of the PerReplica objects in
    the input list.

  Raises:
    ValueError: If any of the objects in the `per_replica_list` is not a tensor.

  """
  # Convert the inputs and targets into a list of PerReplica objects.
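  # `expand_composites=True` also flattens composite tensors (e.g. SparseTensor
  # and RaggedTensor) into their component tensors so each component can be
  # validated individually.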
  per_replica_list = nest.flatten(x, expand_composites=True)
  x_values_list = []
  for x in per_replica_list:
    if not tensor_util.is_tensor(x):
      raise ValueError('Dataset input to the model should be tensors, instead '
                       'they are of type {}'.format(type(x)))

    # At this point x contains tensors in the `DistributedValues` structure.
    x_values = distribution_strategy.unwrap(x)

    if not context.executing_eagerly():
      # Validate that the shape and dtype of all the elements in x are the same.
      validate_all_tensor_shapes(x, x_values)
    validate_all_tensor_types(x, x_values)

    x_values_list.append(x_values[0])
  return x_values_list


def validate_all_tensor_types(x, x_values):
  x_dtype = x_values[0].dtype
  for i in range(1, len(x_values)):
    if x_dtype != x_values[i].dtype:
      raise ValueError('Input tensor dtypes do not match for distributed tensor'
                       ' inputs {}'.format(x))


def validate_all_tensor_shapes(x, x_values):
  # Validate that all the elements in x have the same shape.
  x_shape = x_values[0].shape.as_list()
  for i in range(1, len(x_values)):
    if x_shape != x_values[i].shape.as_list():
      raise ValueError('Input tensor shapes do not match for distributed tensor'
                       ' inputs {}'.format(x))


def _wait_for_variable_initialization(session):
  """Utility to wait for variables to be initialized."""
  all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
  candidate_vars = []
  for v in all_variables:
    if not getattr(v, '_keras_initialized', False):
      candidate_vars.append(v)

  if not candidate_vars:
    return

  while True:
    is_initialized = session.run(
        [variables.is_variable_initialized(v) for v in candidate_vars])
    uninitialized_vars = []
    for flag, v in zip(is_initialized, candidate_vars):
      if not flag:
        uninitialized_vars.append(v)
      v._keras_initialized = True  # pylint: disable=protected-access
    if not uninitialized_vars:
      break


def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  if not multi_worker_util.has_worker_context(
  ) or multi_worker_util.should_load_checkpoint():
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)


def validate_inputs(x, y):
  """Validate inputs when using DistributionStrategy.

  Args:
    x: Model Inputs.
    y: Model Targets.

  Raises:
    ValueError: if input is not a Dataset or a numpy array (when we use
      MirroredStrategy).
  """
  if (isinstance(x, iterator_ops.Iterator) or
      isinstance(y, iterator_ops.Iterator)):
    raise ValueError('`DistributionStrategy` does not support inputs of type '
                     'Iterator. You must pass a `tf.data.Dataset` object or a '
                     'numpy array as input.')


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
  return distribution_strategy.extended._global_batch_size  # pylint: disable=protected-access


# TODO(sourabhbajaj): Remove this once we use the same API for all strategies.
def is_tpu_strategy(strategy):
  """Returns whether `strategy` is a TPUStrategy."""
  return (strategy is not None and
          strategy.__class__.__name__.startswith('TPUStrategy'))


def is_dataset_shape_fully_defined(dataset):
  """Returns whether all output shapes of `dataset` are fully defined."""
  shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
  unknown_shapes = [s for s in shapes if not s.is_fully_defined()]
  return not unknown_shapes


def process_batch_and_step_size(strategy,
                                inputs,
                                batch_size,
                                steps_per_epoch,
                                mode,
                                validation_split=0.):
  """Process the batch size and step size based on input and dist strategy."""
  first_x_value = nest.flatten(inputs)[0]
  if isinstance(first_x_value, np.ndarray):
    num_samples = first_x_value.shape[0]
    if validation_split and 0. < validation_split < 1.:
      num_samples = int(num_samples * (1 - validation_split))
    # Until support for partial batch is implemented across all
    # functions and distribution strategy, we pass `mode` to selectively
    # relax the constraint to consume all the training samples.
    steps_per_epoch, batch_size = get_input_params(
        strategy, num_samples, steps_per_epoch, batch_size, mode=mode)
  return batch_size, steps_per_epoch


def get_input_params(distribution_strategy,
                     num_samples,
                     steps,
                     batch_size,
                     mode=None):
  """Calculate the number of batches and steps/steps_per_epoch.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.
    num_samples: The number of samples from which we determine the batch size
      and steps.
    steps: The specified number of steps.
    batch_size: The specified batch_size.
    mode: ModeKey representing whether input will be used for training,
      evaluation, or prediction. This is used to relax the constraints on
      consuming all the training samples to keep compatibility till we support
      partial batches. If None, then partial batches are not allowed.

  Returns:
    steps: The steps or steps_per_epoch argument depending on whether a user is
        calling `fit`, `evaluate` or `predict`. If partial batches are allowed
        for the given mode, we don't require all the samples to be consumed.
    batch_size: The batch size to be used in model iterations.

  Raises:
    ValueError: If the number of batches or steps evaluates to 0.

  """
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
  use_per_replica_batch = not global_batch_size_supported(
      distribution_strategy)

  # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except for
  # `fit()` on TPUStrategy.
  # In graph mode, the zero batch case in batch norm is not handled due to
  # XLA-GPU regression. Uneven batch sizes are not allowed except
  # for `test()` and `predict()` on TPUStrategy.
  if context.executing_eagerly():
    allow_partial_batch = (mode != ModeKeys.TRAIN or
                           not is_tpu_strategy(distribution_strategy))
  else:
    allow_partial_batch = (mode == ModeKeys.TRAIN or
                           ((mode == ModeKeys.PREDICT or mode == ModeKeys.TEST)
                            and is_tpu_strategy(distribution_strategy)))

  if steps is None:
    if batch_size is None:
      # If neither the batch size nor the number of steps is set, we choose the
      # global batch size as the minimum of the number of samples and 32. 32 is
      # chosen to provide backward compatibility.
      global_batch_size = min(num_samples, 32)
    else:
      # If the user provided the batch size, we need to handle the difference
      # between strategies that use the global batch size and those that use
      # the per-replica batch size.
      global_batch_size = batch_size
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync
    if allow_partial_batch:
      steps = np.ceil(num_samples / global_batch_size).astype(int)
    else:
      if num_samples % global_batch_size:
        raise ValueError('The number of samples %s is not divisible by '
                         'batch size %s.' % (num_samples, global_batch_size))
      steps = num_samples // global_batch_size
  else:
    if batch_size is None:
      # We calculate the batch size based on the number of steps specified.
      if num_samples % steps:
        raise ValueError('The number of samples %s is not divisible by '
                         'steps %s. Please change the number of steps to a '
                         'value that can consume all the samples' % (
                             num_samples, steps))
      global_batch_size = num_samples // steps
    else:
      # If the user provided the batch size, we need to handle the difference
      # between strategies that use the global batch size and those that use
      # the per-replica batch size.
      global_batch_size = batch_size
      if use_per_replica_batch:
        global_batch_size *= distribution_strategy.num_replicas_in_sync

      min_num_samples = global_batch_size * steps
      if allow_partial_batch:
        min_num_samples = global_batch_size * (steps-1) + 1 if steps > 1 else 0
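      # With partial batches allowed, only the final step may be short, so we
      # need at least (steps - 1) full global batches plus one sample.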

      if num_samples < min_num_samples:
        raise ValueError('Number of samples %s is less than samples required '
                         'for specified batch_size %s and steps %s' % (
                             num_samples, global_batch_size, steps))

  # We need to return the per replica or global batch size based on the strategy
  if use_per_replica_batch:
    if global_batch_size % distribution_strategy.num_replicas_in_sync:
      raise ValueError(
          'The batch size (%s) could not be sharded evenly across the sync '
          'replicas (%s) in the distribution strategy.' % (
              global_batch_size, distribution_strategy.num_replicas_in_sync))
    batch_size = global_batch_size // distribution_strategy.num_replicas_in_sync
  else:
    batch_size = global_batch_size

  return steps, batch_size


def get_batch_dimension(iterator):
  shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(iterator))
  # Take the batch size from the first element, as it should be the same for
  # all.
  dims = shapes[0].dims
  return dims[0] if dims else None


def get_iterator(dataset, distribution_strategy):
  with distribution_strategy.scope():
    iterator = distribution_strategy.make_dataset_iterator(dataset)
  initialize_iterator(iterator, distribution_strategy)
  return iterator


def initialize_iterator(iterator, distribution_strategy):
  with distribution_strategy.scope():
    init_op = control_flow_ops.group(iterator.initializer)
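    # Note: in graph mode the grouped initializer op still has to be run
    # explicitly through the Keras session below; in eager mode evaluating the
    # initializers above is expected to be sufficient.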
    if not context.executing_eagerly():
      K.get_session((init_op,)).run(init_op)


def _get_input_from_iterator(iterator, model):
  """Get elements from the iterator and verify the input shape and type."""
  next_element = iterator.get_next()

  # `len(nest.flatten(x))` does not count empty elements such as {}:
  # len(nest.flatten([[0, 1, 2], {}])) is 3 and not 4. The `next_element` is
  # going to get flattened in `_prepare_feed_values` to work around that, and
  # empty elements are filtered out as part of the flattening.
  if len(nest.flatten(next_element)) == len(model.inputs):
    x = next_element
    y = None
    sample_weights = None
  elif len(nest.flatten(next_element)) == (len(model.inputs) +
                                           len(model.outputs)):
    x, y = next_element
    sample_weights = None
  else:
    x, y, sample_weights = next_element

  # Validate that all the elements in x and y are of the same type and shape.
  validate_distributed_dataset_inputs(
      model._distribution_strategy, x, y, sample_weights)
  return x, y, sample_weights


def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Prepare feed values to the model execution function.

  Arguments:
    model: Model to prepare feed values for.
    inputs: List or dict of model inputs.
    targets: Optional list of model targets.
    sample_weights: Optional list of sample weight arrays.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode.
  """
  strategy = model._distribution_strategy
  inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
  if is_tpu_strategy(strategy):
    if sample_weights is not None:
      raise ValueError('TPUStrategy does not support sample weights.')

  # When the inputs are a dict, we want to flatten it in the same order as the
  # input layers, so that the data are fed into the input layers in the
  # correct order.
  if isinstance(inputs, dict):
    inputs = [inputs[key] for key in model._feed_input_names]
  if is_distributing_by_cloning(model):
    inputs = flatten_per_replica_values(strategy, inputs)
    targets = flatten_per_replica_values(strategy, targets)
    # Expand 1-dimensional inputs.
    # TODO(b/124535720): Remove once this standardize data logic is shared with
    # main flow.
    inputs, targets = nest.map_structure(
        training_utils.standardize_single_array, (inputs, targets))
  else:
    inputs = training_utils.ModelInputs(inputs).as_list()

  if mode == ModeKeys.PREDICT:
    sample_weights = []
    targets = []
  elif sample_weights is not None and is_distributing_by_cloning(model):
    if context.executing_eagerly() and not model._compile_distribution:
      raise NotImplementedError('`sample_weight` is not supported when using '
                                'tf.distribute.Strategy in eager mode and '
                                'cloning=True.')
    sample_weights = flatten_per_replica_values(strategy, sample_weights)

  ins = [inputs, targets, sample_weights]
  return tuple(ins)


def is_distributing_by_cloning(model):
  """Decide whether this model is going to be distributed via cloning.

  We are going to distribute the model by cloning in graph mode.

  Args:
    model: Keras model to distribute.

  Returns:
    True if the `model` is going to be distributed using cloning and False
    otherwise.
  """
  if (is_tpu_strategy(model._distribution_strategy) and
      context.executing_eagerly):  # b/137580852
    return False
  elif ops.executing_eagerly_outside_functions():
    return bool(model._compile_distribution)
  return True


def _custom_compile_for_predict(model):
  """Custom compile for TPU predict mode."""
  if not model.built:
    # Model is not compilable because it does not know its number of inputs
    # and outputs, nor their shapes and names. We will compile after the first
    # time the model gets called on training data.
    return
  model._is_compiled = True
  model.total_loss = None
  model.train_function = None
  model.test_function = None
  model.predict_function = None


def _build_network_on_replica(model, mode, inputs=None, targets=None):
  """Build an updated model on replicas.

  We create a new Keras model while sharing the variables from the old graph.
  Building a new sub-graph is required since the original Keras model creates
  placeholders for the input and the output that are not accessible until we
  call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.

  The sharing of weights and layers between the old and the new model
  guarantees that we're using Strategy variables and any updates on either
  model are reflected correctly in callbacks and loop iterations.

  We need to make sure we share the optimizers between the old and the new
  model as well so that optimizer state is not lost if the user runs fit
  multiple times.

  Args:
    model: Model to be replicated across replicas.
    mode: Which of fit/eval/predict is building the distributed network.
    inputs: Input variables to be passed to the model.
    targets: Target tensor to be passed to model.compile.

  Returns:
    A new model with shared layers with the old model.
  """
  # Need to do imports here since we run into a circular dependency error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  # We rely on the internal methods to avoid exposing share_weights in the
  # public API.
  if isinstance(model, sequential.Sequential):
    updated_model = models._clone_sequential_model(
        model, input_tensors=inputs, layer_fn=models.share_weights)
  else:
    updated_model = models._clone_functional_model(
        model, input_tensors=inputs, layer_fn=models.share_weights)
    # Callable losses added directly to a functional Model need to be added
    # here.
    updated_model._callable_losses = model._callable_losses

  # Recast all low precision outputs back to float32 since we only cast
  # the inputs to bfloat16 and not targets. This is done so that we can preserve
  # precision when calculating the loss value.
  def _upcast_low_precision_outputs(output):
    if output.dtype == dtypes.bfloat16:
      return math_ops.cast(output, dtypes.float32)
    else:
      return output
  updated_model.outputs = [_upcast_low_precision_outputs(o)
                           for o in updated_model.outputs]

  if isinstance(targets, tuple):
    targets = nest.flatten(targets)

  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(updated_model)
  else:
    updated_model.compile(
        model.optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=targets)
  return updated_model


def _build_distributed_network(model, strategy, mode, inputs=None,
                               targets=None):
  """Create a cloned model on each replica."""
  with K.get_graph().as_default(), strategy.scope():
    distributed_model = strategy.extended.call_for_each_replica(
        _build_network_on_replica,
        args=(model, mode, inputs, targets))
    set_distributed_model(model, mode, distributed_model)


def _clone_and_build_model(model, mode, inputs=None, targets=None):
  """Clone and build the given keras_model."""
  # We need to do the import here since we run into a circular dependency
  # error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  cloned_model = models.clone_model(model, input_tensors=inputs)

  # Compile and build model.
  if isinstance(model.optimizer, optimizers.TFOptimizer):
    optimizer = model.optimizer
  else:
    optimizer_config = model.optimizer.get_config()
    optimizer = model.optimizer.__class__.from_config(optimizer_config)

  # Recast all low precision outputs back to float32 since we only cast
  # the inputs to bfloat16 and not targets. This is done so that we can preserve
  # precision when calculating the loss value.
  def _upcast_low_precision_outputs(output):
    if output.dtype == dtypes.bfloat16:
      return math_ops.cast(output, dtypes.float32)
    else:
      return output
  cloned_model.outputs = [_upcast_low_precision_outputs(o)
                          for o in cloned_model.outputs]

  if isinstance(targets, tuple):
    targets = nest.flatten(targets)
  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(cloned_model)
  else:
    cloned_model.compile(
        optimizer,
        model.loss,
        metrics=metrics_module.clone_metrics(model._compile_metrics),
        loss_weights=model.loss_weights,
        sample_weight_mode=model.sample_weight_mode,
        weighted_metrics=metrics_module.clone_metrics(
            model._compile_weighted_metrics),
        target_tensors=targets)
  return cloned_model


def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
  """Create a cloned model on each replica."""
  with K.get_graph().as_default(), strategy.scope():
    distributed_model = strategy.extended.call_for_each_replica(
        _clone_and_build_model, args=(model, mode, inputs, targets))
    set_distributed_model(model, mode, distributed_model)
  if mode == ModeKeys.TRAIN:
    model._make_callback_model(distributed_model)


def _make_execution_function(model, mode):
  """Makes or reuses a function to run one step of distributed model execution."""
  if is_distributing_by_cloning(model):
    return _make_execution_function_with_cloning(model, mode)

  distributed_function = get_distributed_function(model, mode)
  if distributed_function:
    return distributed_function

  distribution_function = _make_execution_function_without_cloning(model, mode)
  set_distributed_function(model, mode, distribution_function)
  return distribution_function


def _make_execution_function_without_cloning(model, mode):
  """Creates a function to run one step of distributed model execution."""
  strategy = model._distribution_strategy

  with strategy.scope():
    per_replica_function = _make_replica_execution_function(model, mode)

    def distributed_function(input_fn):
      """A single step of the distributed execution across replicas."""
      x, y, sample_weights = input_fn()
      # Call `Model.{train,test,predict}_on_batch` on every replica passing
      # PerReplicas as arguments.  On every replica inside this call, each
      # PerReplica object will return the value for that replica.  The outputs
      # are PerReplicas too.
      outputs = strategy.experimental_run_v2(
          per_replica_function, args=(x, y, sample_weights))
      # From the PerReplica outputs, reduce or pick the values to return.
      all_outputs = unwrap_outputs(
          strategy, outputs, with_loss_tensor=(mode != ModeKeys.PREDICT))
      return all_outputs

    if not model.run_eagerly:
      distributed_function = def_function.function(distributed_function)
      def execution_function(input_fn):
        # `numpy` translates Tensors to values in Eager mode.
        return [out.numpy() for out in distributed_function(input_fn)]
    else:
      execution_function = distributed_function

    return execution_function


def _make_replica_execution_function(model, mode):
  """A single step of the distributed execution on a replica."""
  if mode == ModeKeys.TRAIN:
    func = model.train_on_batch
  elif mode == ModeKeys.TEST:
    func = model.test_on_batch
  else:

    def predict_on_batch(x, y=None, sample_weights=None):
      del y, sample_weights
      return model.predict_on_batch(x)

    func = predict_on_batch

  if mode != ModeKeys.PREDICT:
    # `reset_metrics` is set to False to maintain stateful metrics across
    # batch-level calls.
    func = functools.partial(func, reset_metrics=False)

  return func


def _make_replicated_models_with_cloning(model, mode):
  """Build models on each replica."""
  strategy = model._distribution_strategy

  # If distributed_model is not built, create one for `mode`.
  if model._compile_distribution:
    clone_model_on_replicas(model, strategy, mode)
  else:
    _build_distributed_network(model, strategy, mode)


def _make_execution_function_with_cloning(model, mode):
  """Clones or re-uses models to run one step of distributed model execution."""
  distributed_model = get_distributed_model(model, mode)
  # TODO(b/134069401): Create a cache for the distributed model and exec
  # function that incorporates additional attributes in the cache key beyond
  # just the mode.
  # If the distributed model for a particular `mode` is already built, use the
  # `_distributed_function` on that distributed model.
  # If you have updated the sample_weight_mode on the model, then you will need
  # to recompile metrics and recreate the execution function. This is indicated
  # by the `_recompile_exec_function` property.
  if (distributed_model and hasattr(distributed_model, '_distributed_function')
      and not (hasattr(distributed_model, '_recompile_exec_function') and
               distributed_model._recompile_exec_function)):
    return distributed_model._distributed_function

  if not distributed_model:
    _make_replicated_models_with_cloning(model, mode)
    distributed_model = get_distributed_model(model, mode)
  assert distributed_model

  # Also create an execution function on that distributed model.
  if context.executing_eagerly():
    distributed_function = _make_eager_execution_function(model, mode)
  else:
    distributed_function = _make_graph_execution_function(model, mode)

  # We cache the distributed execution function on the model since creating
  # distributed models and execution functions are expensive.
  distributed_model._distributed_function = distributed_function
  distributed_model._recompile_exec_function = False
  return distributed_function


def _make_graph_execution_function(model, mode):
  """Makes a function to run one step of distributed model in graph mode."""

  def _per_replica_function(model):
    f = model._make_execution_function(mode)
    return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)

  strategy = model._distribution_strategy
  with strategy.scope():
    # Create train ops on each of the devices when we call
    # `_per_replica_function`.
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_function, args=(get_distributed_model(model, mode),))

    # Initialize the variables in the replicated model. This is necessary for
    # multi-worker training because on some workers, initialization is not
    # needed. This method does initialization or waiting for initialization
    # according to the context object of distribute coordinator.
    init_restore_or_wait_for_variables()

    # Unwrap all the per device values returned from `call_for_each_replica`.
    # Unwrapping per device values gives you a list of values that can be
    # used to construct a new train function that is composed of update ops on
    # all the devices over which the model is distributed.
    (all_inputs, all_outputs, all_updates, all_session_args) = unwrap_values(
        strategy,
        grouped_inputs,
        grouped_outputs,
        grouped_updates,
        grouped_session_args,
        with_loss_tensor=(mode != ModeKeys.PREDICT))

    return K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_{}_function'.format(mode),
        **all_session_args)


def _make_eager_execution_function(model, mode):
  """Makes a function to run one step of distributed model eager execution."""
  def _per_replica_function(model):
    f = model._make_execution_function(mode)
    return (f.inputs, f.outputs)

  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
  # the global one.
  strategy = model._distribution_strategy
  global_graph = K.get_graph()

  with global_graph.as_default(), strategy.scope():
    # First we gather the relevant portions of the model across all replicas.
    # `K._scratch_graph(global_graph)` signals to Keras that it should not
    # lift to a separate graph when creating the per-replica functions.
    with K._scratch_graph(global_graph):
      # Create train ops on each of the devices when we call
      # `_per_replica_function`.
      grouped = strategy.extended.call_for_each_replica(
          _per_replica_function, args=(get_distributed_model(model, mode),))
      grouped_inputs, grouped_outputs = grouped

      # Unwrap all the per device values returned from `call_for_each_replica`.
      # Unwrapping per device values gives you a list of values that can be
      # used to construct a new train function that is composed of
      # inputs/outputs on all the devices over which the model is distributed.
      (all_inputs, all_outputs, _, _) = unwrap_values(
          strategy,
          grouped_inputs,
          grouped_outputs,
          with_loss_tensor=(mode != ModeKeys.PREDICT))

    # Finally, a joint Keras function is created; this one will be created in
    # a separate FuncGraph.
    return K.function(
        all_inputs,
        all_outputs,
        name='eager_distributed_{}_function'.format(mode))


def _copy_weights_to_distributed_model(original_model, mode):
  """Copies weights from original model to distributed models."""
  strategy = original_model._distribution_strategy
  distributed_model = get_distributed_model(original_model, mode)
  if strategy:
    # Copy the weights from the original model to each of the replicated
    # models.
    orig_model_weights = original_model.get_weights()
    first_model = strategy.unwrap(distributed_model)[0]
    set_weights(strategy, first_model, orig_model_weights)


def _copy_weights_to_original_model(model, mode):
  """Copies weights from first distributed model back to original model."""
  if model._distribution_strategy and mode == ModeKeys.TRAIN:
    distributed_model = get_distributed_model(model, mode)
    updated_weights = model._distribution_strategy.unwrap(
        distributed_model)[0].get_weights()
    model.set_weights(updated_weights)


def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
  """Aggregates the per-replica batch-level outputs from a distributed step."""
  if strategy is not None and mode == ModeKeys.PREDICT:
    total_batch_outs = []
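    # `batch_outs` is ordered output-major: for each model output, the
    # per-replica results appear consecutively, so each slice below gathers the
    # replicas belonging to one output before concatenation.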
    for i in range(len(model.outputs)):
      num_replicas = strategy.num_replicas_in_sync
      nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
      total_batch_outs.append(
          concat_along_batch_dimension(nest.flatten(nested_outs)))
    return total_batch_outs
  return batch_outs


def _reset_metrics(model):
  if model._distribution_strategy:
    for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]:
      distributed_model = get_distributed_model(model, mode)
      if distributed_model:
        first_model = model._distribution_strategy.unwrap(distributed_model)[0]
        first_model.reset_metrics()


def get_distributed_model(model, mode):
  key = _generate_cache_key(mode)
  return model._distributed_model_cache.get(key, None)


def set_distributed_model(model, mode, distributed_model):
  key = _generate_cache_key(mode)
  model._distributed_model_cache[key] = distributed_model


def get_distributed_function(model, mode):
  key = _generate_cache_key(mode)
  return model._distributed_function_cache.get(key, None)


def set_distributed_function(model, mode, distributed_function):
  key = _generate_cache_key(mode)
  model._distributed_function_cache[key] = distributed_function


def _generate_cache_key(mode):
  key = hash(mode)
  return key


@tf_contextlib.contextmanager
def distributed_scope(strategy, learning_phase):
  with strategy.scope(), K.learning_phase_scope(learning_phase):
    yield


def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Arguments:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword arguments to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/132666209): Remove this function when we support assign_*
  # for replica-local variables.
  strategy = None
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  else:
    if ds_context.has_strategy():
      strategy = ds_context.get_strategy()

  # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
  is_tpu = is_tpu_strategy(strategy)
  if ((not is_tpu) and strategy and ds_context.in_cross_replica_context()):
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)


def is_current_worker_chief():
  return dc_context.get_current_worker_context().is_chief


def filter_distributed_callbacks(callbacks_list, model):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.
    model: Keras model instance.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """

  if not model._in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to (possibly with tempfile directory).
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')

  if callbacks_list is None or is_current_worker_chief():
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access


def _update_sample_weight_modes(model, mode, sample_weights):
  """Update sample_weight_mode of the distributed model."""
  if is_distributing_by_cloning(model):
    distributed_model = get_distributed_model(model, mode)
    if not distributed_model:
      _make_replicated_models_with_cloning(model, mode)
      distributed_model = get_distributed_model(model, mode)
    distributed_model._recompile_exec_function = any(
        [e.sample_weights_mismatch() for e in model._training_endpoints])

    if sample_weights:
      distributed_models = flatten_per_replica_values(
          model._distribution_strategy, distributed_model)
      # sample_weights is a tuple of 1 list where the number of elements in the
      # list is equal to the number of replicas in sync.
      sample_weights = sample_weights[0]
      if sample_weights and None not in sample_weights:
        for m, sw in zip(distributed_models, sample_weights):
          m._update_sample_weight_modes(sample_weights=[sw])


def concat_along_batch_dimension(outputs):
  """Concats prediction outputs along the batch dimension."""
  if isinstance(outputs[0], sparse_tensor.SparseTensor):
    return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
  if isinstance(outputs[0], ragged_tensor.RaggedTensor):
    return ragged_concat_ops.concat(outputs, axis=0)
  return np.concatenate(outputs)