1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed inputs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import sys
23
24import six
25
26from tensorflow.python import tf2
27from tensorflow.python.data.experimental.ops import batching
28from tensorflow.python.data.experimental.ops import cardinality
29from tensorflow.python.data.experimental.ops import distribute
30from tensorflow.python.data.ops import dataset_ops
31from tensorflow.python.data.ops import iterator_ops
32from tensorflow.python.data.ops import multi_device_iterator_ops
33from tensorflow.python.data.ops import optional_ops
34from tensorflow.python.distribute import device_util
35from tensorflow.python.distribute import distribute_lib
36from tensorflow.python.distribute import distribute_utils
37from tensorflow.python.distribute import distribution_strategy_context
38from tensorflow.python.distribute import input_ops
39from tensorflow.python.distribute import reduce_util
40from tensorflow.python.distribute import values
41from tensorflow.python.distribute.distribute_lib import InputReplicationMode
42from tensorflow.python.eager import context
43from tensorflow.python.framework import composite_tensor
44from tensorflow.python.framework import constant_op
45from tensorflow.python.framework import device as tf_device
46from tensorflow.python.framework import dtypes
47from tensorflow.python.framework import errors
48from tensorflow.python.framework import ops
49from tensorflow.python.framework import sparse_tensor
50from tensorflow.python.framework import tensor_shape
51from tensorflow.python.framework import tensor_util
52from tensorflow.python.framework import type_spec
53from tensorflow.python.ops import array_ops
54from tensorflow.python.ops import control_flow_ops
55from tensorflow.python.ops import math_ops
56from tensorflow.python.ops.ragged import ragged_tensor
57from tensorflow.python.platform import tf_logging as logging
58from tensorflow.python.types import distribute as distribute_types
59from tensorflow.python.util import nest
60from tensorflow.python.util.compat import collections_abc
61from tensorflow.python.util.deprecation import deprecated
62from tensorflow.python.util.tf_export import tf_export
63from tensorflow.tools.docs import doc_controls
64
65
66def get_distributed_dataset(dataset,
67                            input_workers,
68                            strategy,
69                            num_replicas_in_sync=None,
70                            input_context=None,
71                            options=None,
72                            build=True):
73  """Returns a distributed dataset from the given tf.data.Dataset instance.
74
  This is a common function used by all strategies to return a distributed
  dataset. The distributed dataset instance returned is different depending on
  whether we are in a TF 1 or TF 2 context; the instances differ in the APIs
  they support.
80
81  Args:
82    dataset: a tf.data.Dataset instance.
83    input_workers: an InputWorkers object which specifies devices on which
84        iterators should be created.
85    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
86        handle last partial batch.
87    num_replicas_in_sync: Optional integer. If this is not None, the value is
88        used to decide how to rebatch datasets into smaller batches so that
89        the total batch size for each step (across all workers and replicas)
90        adds up to `dataset`'s batch size.
91    input_context: `InputContext` for sharding. Only pass this in for between
92        graph multi-worker cases where there is only one `input_worker`. In
93        these cases, we will shard based on the `input_pipeline_id` and
94        `num_input_pipelines` in the `InputContext`.
95    options: Default is None. `tf.distribute.InputOptions` used to control
96        options on how this dataset is distributed.
97    build: whether to build underlying datasets when a DistributedDataset is
98        created. This is only useful for `ParameterServerStrategy` now.
99
100  Returns:
101    A distributed dataset instance.
102  """
103  if tf2.enabled():
104    return DistributedDataset(
105        input_workers,
106        strategy,
107        dataset,
108        num_replicas_in_sync=num_replicas_in_sync,
109        input_context=input_context,
110        build=build,
111        options=options)
112  else:
113    return DistributedDatasetV1(
114        dataset,
115        input_workers,
116        strategy,
117        num_replicas_in_sync=num_replicas_in_sync,
118        input_context=input_context,
119        options=options)
120
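# A minimal usage sketch (illustrative only, not part of the original module):
# how a strategy implementation might wrap a `tf.data.Dataset` with the helper
# above. The single-worker, two-GPU device layout below is an assumption.
def _example_get_distributed_dataset(strategy, dataset):
  """Illustrative helper showing a typical `get_distributed_dataset` call."""
  input_workers = InputWorkers(
      [("/device:CPU:0", ("/device:GPU:0", "/device:GPU:1"))])
  return get_distributed_dataset(
      dataset, input_workers, strategy, num_replicas_in_sync=2)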
121
122def get_distributed_datasets_from_function(dataset_fn,
123                                           input_workers,
124                                           input_contexts,
125                                           strategy,
126                                           options=None,
127                                           build=True):
128  """Returns a distributed dataset from the given input function.
129
  This is a common function used by all strategies to return a distributed
  dataset. The distributed dataset instance returned is different depending on
  whether we are in a TF 1 or TF 2 context; the instances differ in the APIs
  they support.
135
136  Args:
137    dataset_fn: a function that returns a tf.data.Dataset instance.
138    input_workers: an InputWorkers object which specifies devices on which
139        iterators should be created.
140    input_contexts: A list of `InputContext` instances to be passed to call(s)
141        to `dataset_fn`. Length and order should match worker order in
142        `worker_device_pairs`.
143    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
144        handle last partial batch.
145    options: Default is None. `tf.distribute.InputOptions` used to control
146        options on how this dataset is distributed.
147    build: whether to build underlying datasets when a
148        `DistributedDatasetFromFunction` is created. This is only useful for
149        `ParameterServerStrategy` now.
150
151  Returns:
152    A distributed dataset instance.
153
  Raises:
    ValueError: if `options.experimental_replication_mode` and
      `options.experimental_place_dataset_on_device` are not consistent.
  """
158  if (options is not None and
159      options.experimental_replication_mode != InputReplicationMode.PER_REPLICA
160      and options.experimental_place_dataset_on_device):
161    raise ValueError(
162        "When `experimental_place_dataset_on_device` is set for dataset "
163        "placement, you must also specify `PER_REPLICA` for the "
164        "replication mode")
165
166  if (options is not None and
167      options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
168      and options.experimental_fetch_to_device and
169      options.experimental_place_dataset_on_device):
    raise ValueError(
        "`experimental_place_dataset_on_device` cannot be set to True "
        "when `experimental_fetch_to_device` is True and "
        "the replication mode is set to `PER_REPLICA`.")
174
175  if tf2.enabled():
176    return DistributedDatasetsFromFunction(
177        input_workers,
178        strategy,
179        input_contexts=input_contexts,
180        dataset_fn=dataset_fn,
181        options=options,
182        build=build,
183    )
184  else:
185    return DistributedDatasetsFromFunctionV1(input_workers, strategy,
186                                             input_contexts, dataset_fn,
187                                             options)
188
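# Illustrative sketch (assumptions: a single input worker feeding two GPUs and
# a global batch size of 8): how a strategy might call
# `get_distributed_datasets_from_function` with one `InputContext` per input
# worker.
def _example_get_distributed_datasets_from_function(strategy):
  """Illustrative helper showing a typical call with a `dataset_fn`."""

  def dataset_fn(input_context):
    # A `dataset_fn` is expected to batch per replica and shard per pipeline.
    batch_size = input_context.get_per_replica_batch_size(
        global_batch_size=8)
    ds = dataset_ops.DatasetV2.range(64).batch(batch_size)
    return ds.shard(input_context.num_input_pipelines,
                    input_context.input_pipeline_id)

  input_workers = InputWorkers(
      [("/device:CPU:0", ("/device:GPU:0", "/device:GPU:1"))])
  input_contexts = [
      distribute_lib.InputContext(
          num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=2)
  ]
  return get_distributed_datasets_from_function(
      dataset_fn, input_workers, input_contexts, strategy)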
189
190def get_iterator_spec_from_dataset(strategy, dataset):
191  """Returns an iterator spec from dataset function.
192
  This function constructs the type spec for the iterator obtained from
  iter(dataset).
195
196  Args:
197    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
198        handle last partial batch.
199    dataset: A tf.data.Dataset instance. If using a function that returns a
200      tf.data.Dataset instance, pass dataset_fn.structured_outputs.
201
202  Returns:
    A type spec for the iterator of the given dataset instance.
204
205  """
206  output_element_spec = dataset.element_spec
207  if isinstance(dataset._type_spec,  # pylint: disable=protected-access
208                (DistributedDatasetSpec,
209                 DistributedDatasetsFromFunctionSpec)):
210    iterator_type_spec = DistributedIteratorSpec(
211        strategy.extended._input_workers_with_options(  # pylint: disable=protected-access
212        ), output_element_spec,
213        strategy.extended._container_strategy(), True,  # pylint: disable=protected-access
214        None)
215  else:
216    if strategy.extended._num_gpus_per_worker:  # pylint: disable=protected-access
217      logging.warning(
218          f"{strategy.extended._num_gpus_per_worker} GPUs "  # pylint: disable=protected-access
219          "are allocated per worker. Please use DistributedDataset by "
220          "calling strategy.experimental_distribute_dataset or strategy."
221          "distribute_datasets_from_function to make best use of GPU "
222          "resources"
223      )
224    iterator_type_spec = iterator_ops.IteratorSpec(output_element_spec)
225  return iterator_type_spec
226
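# Hedged usage sketch (assumed `strategy` and `dataset` objects): the iterator
# spec of a distributed dataset can be used, for example, as part of a
# `tf.function` input_signature.
def _example_get_iterator_spec(strategy, dataset):
  """Illustrative helper returning the iterator spec of a distributed dataset."""
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  return get_iterator_spec_from_dataset(strategy, dist_dataset)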
227
228@tf_export("distribute.DistributedIterator", v1=[])
229class DistributedIteratorInterface(collections_abc.Iterator,
230                                   distribute_types.Iterator):
231  """An iterator over `tf.distribute.DistributedDataset`.
232
233  `tf.distribute.DistributedIterator` is the primary mechanism for enumerating
234  elements of a `tf.distribute.DistributedDataset`. It supports the Python
235  Iterator protocol, which means it can be iterated over using a for-loop or by
236  fetching individual elements explicitly via `get_next()`.
237
  You can create a `tf.distribute.DistributedIterator` by calling `iter` on
  a `tf.distribute.DistributedDataset` or by creating a Python loop over a
  `tf.distribute.DistributedDataset`.
241
242  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
243  on distributed input for more examples and caveats.
244  """
245
246  def get_next(self):
247    """Returns the next input from the iterator for all replicas.
248
249    Example use:
250
251    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
252    >>> dataset = tf.data.Dataset.range(100).batch(2)
253    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
254    >>> dist_dataset_iterator = iter(dist_dataset)
255    >>> @tf.function
256    ... def one_step(input):
257    ...   return input
258    >>> step_num = 5
259    >>> for _ in range(step_num):
260    ...   strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
261    >>> strategy.experimental_local_results(dist_dataset_iterator.get_next())
262    (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
263     <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
264
265    Returns:
266      A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains
267      the next input for all replicas.
268
269    Raises:
270      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
271    """
272    raise NotImplementedError(
273        "DistributedIterator.get_next() must be implemented in descendants.")
274
275  @property
276  def element_spec(self):
277    # pylint: disable=line-too-long
278    """The type specification of an element of `tf.distribute.DistributedIterator`.
279
280    Example usage:
281
282    >>> global_batch_size = 16
283    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
284    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
285    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
286    >>> distributed_iterator.element_spec
287    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
288                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
289     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
290                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
291
292    Returns:
293      A nested structure of `tf.TypeSpec` objects matching the structure of an
294      element of this `tf.distribute.DistributedIterator`. This returned value
295      is typically a `tf.distribute.DistributedValues` object and specifies the
296      `tf.TensorSpec` of individual components.
297    """
298    raise NotImplementedError(
299        "DistributedIterator.element_spec() must be implemented in descendants")
300
301  def get_next_as_optional(self):
302    # pylint: disable=line-too-long
303    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.
304
305    If the `tf.distribute.DistributedIterator` has reached the end of the
306    sequence, the returned `tf.experimental.Optional` will have no value.
307
308    Example usage:
309
310    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
311    >>> global_batch_size = 2
312    >>> steps_per_loop = 2
313    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
314    >>> distributed_iterator = iter(
315    ...     strategy.experimental_distribute_dataset(dataset))
316    >>> def step_fn(x):
317    ...   # train the model with inputs
318    ...   return x
319    >>> @tf.function
320    ... def train_fn(distributed_iterator):
321    ...   for _ in tf.range(steps_per_loop):
322    ...     optional_data = distributed_iterator.get_next_as_optional()
323    ...     if not optional_data.has_value():
324    ...       break
325    ...     per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),))
326    ...     tf.print(strategy.experimental_local_results(per_replica_results))
327    >>> train_fn(distributed_iterator)
328    ... # ([0 1], [2 3])
329    ... # ([4], [])
330
331    Returns:
      A `tf.experimental.Optional` object representing the next value from the
      `tf.distribute.DistributedIterator` (if it has one) or no value.
334    """
335    # pylint: enable=line-too-long
336    raise NotImplementedError(
337        "get_next_as_optional() not implemented in descendants")
338
339
340@tf_export("distribute.DistributedDataset", v1=[])
341class DistributedDatasetInterface(collections_abc.Iterable,
342                                  distribute_types.Iterable):
343  # pylint: disable=line-too-long
344  """Represents a dataset distributed among devices and machines.
345
346  A `tf.distribute.DistributedDataset` could be thought of as a "distributed"
347  dataset. When you use `tf.distribute` API to scale training to multiple
348  devices or machines, you also need to distribute the input data, which leads
349  to a `tf.distribute.DistributedDataset` instance, instead of a
350  `tf.data.Dataset` instance in the non-distributed case. In TF 2.x,
351  `tf.distribute.DistributedDataset` objects are Python iterables.
352
  Note: `tf.distribute.DistributedDataset` instances are *not* of type
  `tf.data.Dataset`. They support only the two usages mentioned below:
  iteration and `element_spec`. We don't support any other APIs to transform
  or inspect the dataset.
357
358  There are two APIs to create a `tf.distribute.DistributedDataset` object:
  `tf.distribute.Strategy.experimental_distribute_dataset(dataset)` and
360  `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`.
361  *When to use which?* When you have a `tf.data.Dataset` instance, and the
362  regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance
363  with a new batch size that is equal to the global batch size divided by the
364  number of replicas in sync) and autosharding (i.e. the
365  `tf.data.experimental.AutoShardPolicy` options) work for you, use the former
  API. Otherwise, if you are *not* using a canonical `tf.data.Dataset` instance,
  or you would like to customize the batch splitting or sharding, you can wrap
  this logic in a `dataset_fn` and use the latter API. Both APIs handle
  prefetching to devices for the user. For more details and examples, follow
  the links to the APIs.
371
372
373  There are two main usages of a `DistributedDataset` object:
374
375  1. Iterate over it to generate the input for a single device or multiple
376  devices, which is a `tf.distribute.DistributedValues` instance. To do this,
377  you can:
378
379    * use a pythonic for-loop construct:
380
381      >>> global_batch_size = 4
382      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
383      >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size)
384      >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
385      >>> @tf.function
386      ... def train_step(input):
387      ...   features, labels = input
388      ...   return labels - 0.3 * features
389      >>> for x in dist_dataset:
390      ...   # train_step trains the model using the dataset elements
391      ...   loss = strategy.run(train_step, args=(x,))
392      ...   print("Loss is", loss)
393      Loss is PerReplica:{
394        0: tf.Tensor(
395      [[0.7]
396       [0.7]], shape=(2, 1), dtype=float32),
397        1: tf.Tensor(
398      [[0.7]
399       [0.7]], shape=(2, 1), dtype=float32)
400      }
401
402      Placing the loop inside a `tf.function` will give a performance boost.
      However, `break` and `return` are currently not supported if the loop is
404      placed inside a `tf.function`. We also don't support placing the loop
405      inside a `tf.function` when using
406      `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
407      `tf.distribute.experimental.TPUStrategy` with multiple workers.
408
409    * use `__iter__` to create an explicit iterator, which is of type
410      `tf.distribute.DistributedIterator`
411
412      >>> global_batch_size = 4
413      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
414      >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size)
415      >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
416      >>> @tf.function
417      ... def distributed_train_step(dataset_inputs):
418      ...   def train_step(input):
419      ...     loss = tf.constant(0.1)
420      ...     return loss
421      ...   per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
422      ...   return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,axis=None)
423      >>> EPOCHS = 2
424      >>> STEPS = 3
425      >>> for epoch in range(EPOCHS):
426      ...   total_loss = 0.0
427      ...   num_batches = 0
428      ...   dist_dataset_iterator = iter(train_dist_dataset)
429      ...   for _ in range(STEPS):
430      ...     total_loss += distributed_train_step(next(dist_dataset_iterator))
431      ...     num_batches += 1
432      ...   average_train_loss = total_loss / num_batches
433      ...   template = ("Epoch {}, Loss: {:.4f}")
434      ...   print (template.format(epoch+1, average_train_loss))
435      Epoch 1, Loss: 0.2000
436      Epoch 2, Loss: 0.2000
437
438
439    To achieve a performance improvement, you can also wrap the `strategy.run`
440    call with a `tf.range` inside a `tf.function`. This runs multiple steps in a
441    `tf.function`. Autograph will convert it to a `tf.while_loop` on the worker.
    However, it is less flexible compared with running a single step inside
    `tf.function`. For example, you cannot run things eagerly or arbitrary
    Python code within the steps.
445
446
447  2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`.
448
449    `tf.distribute.DistributedDataset` generates
450    `tf.distribute.DistributedValues` as input to the devices. If you pass the
451    input to a `tf.function` and would like to specify the shape and type of
452    each Tensor argument to the function, you can pass a `tf.TypeSpec` object to
453    the `input_signature` argument of the `tf.function`. To get the
454    `tf.TypeSpec` of the input, you can use the `element_spec` property of the
455    `tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator`
456    object.
457
458    For example:
459
460    >>> global_batch_size = 4
461    >>> epochs = 1
462    >>> steps_per_epoch = 1
463    >>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
464    >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
465    >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
466    >>> @tf.function(input_signature=[dist_dataset.element_spec])
467    ... def train_step(per_replica_inputs):
468    ...   def step_fn(inputs):
469    ...     return tf.square(inputs)
470    ...   return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))
471    >>> for _ in range(epochs):
472    ...   iterator = iter(dist_dataset)
473    ...   for _ in range(steps_per_epoch):
474    ...     output = train_step(next(iterator))
475    ...     print(output)
476    PerReplica:{
477      0: tf.Tensor(
478    [[4.]
479     [4.]], shape=(2, 1), dtype=float32),
480      1: tf.Tensor(
481    [[4.]
482     [4.]], shape=(2, 1), dtype=float32)
483    }
484
485
486  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
487  on distributed input for more examples and caveats.
488  """
489
490  def __iter__(self):
491    """Creates an iterator for the `tf.distribute.DistributedDataset`.
492
493    The returned iterator implements the Python Iterator protocol.
494
495    Example usage:
496
497    >>> global_batch_size = 4
498    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
499    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
500    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
501    >>> print(next(distributed_iterator))
502    PerReplica:{
503      0: tf.Tensor([1 2], shape=(2,), dtype=int32),
504      1: tf.Tensor([3 4], shape=(2,), dtype=int32)
505    }
506
507    Returns:
      A `tf.distribute.DistributedIterator` instance for the given
509      `tf.distribute.DistributedDataset` object to enumerate over the
510      distributed data.
511    """
512    raise NotImplementedError("Must be implemented in descendants")
513
514  @property
515  def element_spec(self):
516    """The type specification of an element of this `tf.distribute.DistributedDataset`.
517
518    Example usage:
519
520    >>> global_batch_size = 16
521    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
522    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
523    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
524    >>> dist_dataset.element_spec
525    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
526                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
527     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
528                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
529
530    Returns:
531      A nested structure of `tf.TypeSpec` objects matching the structure of an
532      element of this `tf.distribute.DistributedDataset`. This returned value is
533      typically a `tf.distribute.DistributedValues` object and specifies the
534      `tf.TensorSpec` of individual components.
535    """
536    raise NotImplementedError(
537        "DistributedDataset.element_spec must be implemented in descendants.")
538
539  @doc_controls.do_not_generate_docs
540  def reduce(self, initial_state, reduce_func):
541    raise NotImplementedError(
542        "DistributedDataset.reduce must be implemented in descendants.")
543
544
545class InputWorkers(object):
546  """A 1-to-many mapping from input worker devices to compute devices."""
547
548  # TODO(ishark): Remove option canonicalize_devices and make all the callers
549  # pass canonicalized or raw device strings as relevant from strategy.
550  def __init__(self, worker_device_pairs, canonicalize_devices=True):
551    """Initialize an `InputWorkers` object.
552
553    Args:
554      worker_device_pairs: A sequence of pairs: `(input device, a tuple of
555        compute devices fed by that input device)`.
      canonicalize_devices: Whether to canonicalize devices for workers fully
        or partially. If False, it will partially canonicalize devices by
        removing job and task.
559    """
560    self._worker_device_pairs = worker_device_pairs
561    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
562    self._canonicalize_devices = canonicalize_devices
563    if canonicalize_devices:
564      self._fed_devices = tuple(
565          tuple(device_util.canonicalize(d)
566                for d in f)
567          for _, f in self._worker_device_pairs)
568    else:
569      self._fed_devices = tuple(
570          tuple(device_util.canonicalize_without_job_and_task(d)
571                for d in f)
572          for _, f in self._worker_device_pairs)
573
574  @property
575  def num_workers(self):
576    return len(self._input_worker_devices)
577
578  @property
579  def worker_devices(self):
580    return self._input_worker_devices
581
582  def compute_devices_for_worker(self, worker_index):
583    return self._fed_devices[worker_index]
584
585  def __repr__(self):
586    devices = self.worker_devices
587    debug_repr = ",\n".join("  %d %s: %s" %
588                            (i, devices[i], self._fed_devices[i])
589                            for i in range(len(devices)))
590    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
591
592  def serialize(self):
593    return (self._worker_device_pairs, self._canonicalize_devices)
594
595  def deserialize(self, serialized):
    return InputWorkers(*serialized)
597
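# A small hypothetical sketch (not part of the original module) of the
# 1-to-many mapping `InputWorkers` captures: one CPU input device on a single
# host feeding two GPU compute devices.
def _example_input_workers():
  """Illustrative helper constructing an `InputWorkers` mapping."""
  workers = InputWorkers(
      [("/job:localhost/replica:0/task:0/device:CPU:0",
        ("/job:localhost/replica:0/task:0/device:GPU:0",
         "/job:localhost/replica:0/task:0/device:GPU:1"))])
  assert workers.num_workers == 1
  # The compute devices are canonicalized according to `canonicalize_devices`.
  return workers.compute_devices_for_worker(0)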
598
599def _get_next_as_optional(iterator, strategy, return_per_replica=False):
600  """Returns an empty dataset indicator and the next input from the iterator.
601
602  Args:
603    iterator: a DistributedIterator object.
604    strategy: the `tf.distribute.Strategy` instance.
605    return_per_replica: a boolean. If True, the returned data will be wrapped
606      with `PerReplica` structure. Otherwise it is a 2D
607      num_input_workers*num_replicas_per_worker list.
608
609  Returns:
610    A tuple (a boolean tensor indicating whether the next batch has value
611    globally, data from all replicas).
612  """
613  replicas = []
614  worker_has_values = []
615  worker_devices = []
616  with distribution_strategy_context.enter_or_assert_strategy(strategy):
617    if distribution_strategy_context.get_replica_context() is not None:
618      raise ValueError("next(iterator) should be called from outside of "
619                       "replica_fn. e.g. strategy.run(replica_fn, "
620                       "args=(next(iterator),))")
621
622  for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
623    with ops.device(worker):
624      worker_has_value, next_element = (
625          iterator._iterators[i].get_next_as_list())  # pylint: disable=protected-access
626      # Collective all-reduce requires explicit devices for inputs.
627      with ops.device("/cpu:0"):
628        # Converting to integers for all-reduce.
629        worker_has_value = math_ops.cast(worker_has_value, dtypes.int64)
630        worker_devices.append(worker_has_value.device)
631        worker_has_values.append(worker_has_value)
632      # Make `replicas` a flat list of values across all replicas.
633      replicas.append(next_element)
634
635  if return_per_replica:
636    flattened_data = []
637    for per_worker_data in replicas:
638      flattened_data.extend(per_worker_data)
639    replicas = _create_per_replica(flattened_data, strategy)
640
641  # Run an all-reduce to see whether any worker has values.
642  # TODO(b/131423105): we should be able to short-cut the all-reduce in some
643  # cases.
644  if getattr(strategy.extended, "_support_per_replica_values", True):
645    # `reduce` expects a `PerReplica`, so we pass it one, even
646    # though it doesn't actually have a value per replica
647    worker_has_values = values.PerReplica(worker_has_values)
648    global_has_value = strategy.reduce(
649        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
650  else:
651    assert len(worker_has_values) == 1
652    global_has_value = worker_has_values[0]
653  global_has_value = array_ops.reshape(
654      math_ops.cast(global_has_value, dtypes.bool), [])
655  return global_has_value, replicas
656
657
658def _is_statically_shaped(element_spec):
659  """Test if an iterator output is statically shaped.
660
661  For sparse and ragged tensors this only tests the batch dimension.
662
663  Args:
    element_spec: a nested structure of `tf.TypeSpec`. The element spec of the
      dataset of the iterator.

  Returns:
    True if the shape is static, False otherwise.
669  """
670
671  for spec in nest.flatten(element_spec):
672    if isinstance(
673        spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
      # For sparse or ragged tensors, we only need to check the first
      # dimension for get_next_as_optional. This is because when these
      # tensors get batched by the dataset, only the batch dimension
      # is set.
678      if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
679        return False
680    else:
681      for component in nest.flatten(spec._component_specs):  # pylint: disable=protected-access
682        if not component.shape.is_fully_defined():
683          return False
684  return True
685
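# Illustrative check (the specs below are assumptions, not from the original
# module): a sparse spec with a known batch dimension counts as statically
# shaped, while one with an unknown batch dimension does not.
def _example_is_statically_shaped():
  """Illustrative helper exercising `_is_statically_shaped`."""
  static_spec = sparse_tensor.SparseTensorSpec(
      shape=tensor_shape.TensorShape([8, 2]), dtype=dtypes.float32)
  dynamic_spec = sparse_tensor.SparseTensorSpec(
      shape=tensor_shape.TensorShape([None, 2]), dtype=dtypes.float32)
  return (_is_statically_shaped(static_spec),
          _is_statically_shaped(dynamic_spec))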
686
687class DistributedIteratorBase(DistributedIteratorInterface):
688  """Common implementation for all input iterators."""
689
690  # pylint: disable=super-init-not-called
691  def __init__(self, input_workers, iterators, strategy,
692               enable_get_next_as_optional):
693    assert isinstance(input_workers, InputWorkers)
694    if not input_workers.worker_devices:
695      raise ValueError("Should have at least one worker for input iterator.")
696
697    self._iterators = iterators
698    self._input_workers = input_workers
699    self._strategy = strategy
700    self._enable_get_next_as_optional = enable_get_next_as_optional
701
702  def next(self):
703    return self.__next__()
704
705  def __next__(self):
706    try:
707      return self.get_next()
708    except errors.OutOfRangeError:
709      raise StopIteration
710
711  def __iter__(self):
712    return self
713
714  def get_next_as_optional(self):
715    global_has_value, replicas = _get_next_as_optional(
716        self, self._strategy, return_per_replica=True)
717
718    def return_none():
719      return optional_ops.Optional.empty(self._element_spec)
720
721    return control_flow_ops.cond(
722        global_has_value, lambda: optional_ops.Optional.from_value(replicas),
723        return_none)
724
725  def get_next(self, name=None):
726    """Returns the next input from the iterator for all replicas."""
727    if not self._enable_get_next_as_optional:
728      with distribution_strategy_context.enter_or_assert_strategy(
729          self._strategy):
730        if distribution_strategy_context.get_replica_context() is not None:
731          raise ValueError("next(iterator) should be called from outside of "
732                           "replica_fn. e.g. strategy.run(replica_fn, "
733                           "args=(next(iterator),))")
734
735      replicas = []
736      for i, worker in enumerate(self._input_workers.worker_devices):
737        if name is not None:
738          d = tf_device.DeviceSpec.from_string(worker)
739          new_name = "%s_%s_%d" % (name, d.job, d.task)
740        else:
741          new_name = None
742        with ops.device(worker):
743          # Make `replicas` a flat list of values across all replicas.
744          replicas.extend(
745              self._iterators[i].get_next_as_list_static_shapes(new_name))
746      return _create_per_replica(replicas, self._strategy)
747
748    out_of_range_replicas = []
749    def out_of_range_fn(worker_index, device):
750      """This function will throw an OutOfRange error."""
      # As this is only called when there is no data left, calling
      # get_next() will trigger an OutOfRange error.
753      data = self._iterators[worker_index].get_next(device)
754      out_of_range_replicas.append(data)
755      return data
756
757    global_has_value, replicas = _get_next_as_optional(
758        self, self._strategy, return_per_replica=False)
759    results = []
760    for i, worker in enumerate(self._input_workers.worker_devices):
761      with ops.device(worker):
762        devices = self._input_workers.compute_devices_for_worker(i)
763        for j, device in enumerate(devices):
764          with ops.device(device):
765            # pylint: disable=undefined-loop-variable
766            # pylint: disable=cell-var-from-loop
767            # It is fine for the lambda to capture variables from the loop as
768            # the lambda is executed in the loop as well.
769            result = control_flow_ops.cond(
770                global_has_value,
771                lambda: replicas[i][j],
772                lambda: out_of_range_fn(i, device),
773                strict=True,
774            )
775            # pylint: enable=cell-var-from-loop
776            # pylint: enable=undefined-loop-variable
777            results.append(result)
778    replicas = results
779
780    return _create_per_replica(replicas, self._strategy)
781
782
783class DistributedIteratorV1(DistributedIteratorBase):
784  """Input Iterator for a distributed dataset."""
785
786  # We need a private initializer method for re-initializing multidevice
787  # iterators when used with Keras training loops. If we don't reinitialize the
788  # iterator we run into memory leak issues (b/123315763).
789  @property
790  def _initializer(self):
791    init_ops = []
792    for it in self._iterators:
793      init_ops.extend(it.initialize())
794    return control_flow_ops.group(init_ops)
795
796  @deprecated(None, "Use the iterator's `initializer` property instead.")
797  def initialize(self):
798    """Initialize underlying iterators.
799
800    Returns:
801      A list of any initializer ops that should be run.
802    """
803    return self._initializer
804
805  @property
806  def initializer(self):
807    """Returns a list of ops that initialize the iterator."""
808    return self.initialize()
809
810  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
811  @property
812  def output_classes(self):
813    return self._iterators[0].output_classes
814
815  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
816  @property
817  def output_shapes(self):
818    return self._iterators[0].output_shapes
819
820  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
821  @property
822  def output_types(self):
823    return self._iterators[0].output_types
824
825  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
826  def get_iterator(self, worker):
827    for i, w in enumerate(self._input_workers.worker_devices):
828      if worker == w:
829        return self._iterators[i]
830    return None
831
832  @property
833  def element_spec(self):
834    """The type specification of an element of this iterator."""
835    return self._element_spec
836
837
838class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec):
839  """Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction."""
840
841  __slots__ = [
842      "_input_workers", "_element_spec", "_strategy",
843      "_enable_get_next_as_optional", "_options",
844      "_canonicalize_devices"
845  ]
846
847  def __init__(self,
848               input_workers,
849               element_spec,
850               strategy,
851               options,
852               enable_get_next_as_optional=None):
853    # We don't want to allow deserialization of this class because we don't
    # serialize the strategy object. Currently the only place where
    # _deserialize is called is when we save/restore using SavedModels.
856    if isinstance(input_workers, tuple):
857      raise NotImplementedError("DistributedIteratorSpec does not have support "
858                                "for deserialization.")
859    else:
860      self._input_workers = input_workers
861      self._element_spec = element_spec
862      self._strategy = strategy
863      self._enable_get_next_as_optional = enable_get_next_as_optional
864      self._options = options
865      if self._strategy:
866        self._canonicalize_devices = getattr(self._strategy,
867                                             "_canonicalize_devices", True)
868      else:
869        self._canonicalize_devices = True
870
871  def _serialize(self):
872    # We cannot serialize the strategy object so we convert it to an id that we
873    # can use for comparison.
874    return (self._input_workers.serialize(), self._element_spec,
875            id(self._strategy), id(self._options))
876
877  def _deserialize(self):
878    raise ValueError(
879        f"Deserialization is currently unsupported for {type(self)}.")
880
881  def sanity_check_type(self, other):
882    """Returns the most specific TypeSpec compatible with `self` and `other`.
883
884    Args:
885      other: A `TypeSpec`.
886
887    Raises:
888      ValueError: If there is no TypeSpec that is compatible with both `self`
889        and `other`.
890    """
891    # pylint: disable=protected-access
892    if type(self) is not type(other):
893      raise ValueError("No TypeSpec is compatible with both %s and %s" %
894                       (self, other))
895    if self._input_workers.serialize() != other._input_workers.serialize():
896      raise ValueError("_input_workers is not compatible with both %s "
897                       "and %s" % (self, other))
898    if self._strategy is not other._strategy:
899      raise ValueError("tf.distribute strategy is not compatible with both %s "
900                       "and %s" % (self, other))
901
902
903class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec):
904  """Type specification for `DistributedIterator`."""
905
906  def __init__(self, input_workers, element_spec, strategy,
907               enable_get_next_as_optional, options):
908    super(DistributedIteratorSpec,
909          self).__init__(input_workers, element_spec, strategy, options,
910                         enable_get_next_as_optional)
911
912  @property
913  def value_type(self):
914    return DistributedIterator
915
916  # Overriding this method so that we can merge and reconstruct the spec object
917  def most_specific_compatible_type(self, other):
918    """Returns the most specific TypeSpec compatible with `self` and `other`.
919
920    Args:
921      other: A `TypeSpec`.
922
923    Raises:
924      ValueError: If there is no TypeSpec that is compatible with both `self`
925        and `other`.
926    """
927    # pylint: disable=protected-access
928    self.sanity_check_type(other)
929    element_spec = nest.map_structure(
930        lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
931        other._element_spec)
932    return DistributedIteratorSpec(self._input_workers, element_spec,
933                                   self._strategy,
934                                   self._enable_get_next_as_optional,
935                                   self._options)
936
937  @property
938  def _component_specs(self):
939    specs = []
940    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
941
942    for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
943      element_spec = nest.map_structure(
944          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
945      specs.append(
946          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
947                                           element_spec, self._options,
948                                           self._canonicalize_devices))
949    return specs
950
951  def _to_components(self, value):
952    return value._iterators  # pylint: disable=protected-access
953
954  def _from_components(self, components):
955    return DistributedIterator(
956        input_workers=self._input_workers,
957        iterators=None,
958        components=components,
959        element_spec=self._element_spec,
960        strategy=self._strategy,
961        enable_get_next_as_optional=self._enable_get_next_as_optional,
962        options=self._options)
963
964  @staticmethod
965  def from_value(value):
966    # pylint: disable=protected-access
967    return DistributedIteratorSpec(value._input_workers, value._element_spec,
968                                   value._strategy,
969                                   value._enable_get_next_as_optional,
970                                   value._options)
971
972  def _with_tensor_ranks_only(self):
973    element_spec = nest.map_structure(
974        lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
975        self._element_spec)
976    return DistributedIteratorSpec(self._input_workers, element_spec,
977                                   self._strategy,
978                                   self._enable_get_next_as_optional,
979                                   self._options)
980
981
982class DistributedIterator(DistributedIteratorBase,
983                          composite_tensor.CompositeTensor):
984  """Input Iterator for a distributed dataset."""
985
986  def __init__(self,
987               input_workers=None,
988               iterators=None,
989               strategy=None,
990               components=None,
991               element_spec=None,
992               enable_get_next_as_optional=False,
993               options=None):
994    if input_workers is None:
995      raise ValueError("`input_workers` should be "
996                       "provided.")
997
998    error_message = ("Either `input_workers` or "
999                     "both `components` and `element_spec` need to be "
1000                     "provided.")
1001    self._options = options
1002
1003    if iterators is None:
1004      if (components is None or element_spec is None):
1005        raise ValueError(error_message)
1006      self._element_spec = element_spec
1007      self._input_workers = input_workers
1008      self._iterators = components
1009      self._strategy = strategy
1010      self._enable_get_next_as_optional = enable_get_next_as_optional
1011    else:
1012      if (components is not None and element_spec is not None):
1013        raise ValueError(error_message)
1014
1015      super(DistributedIterator,
1016            self).__init__(input_workers, iterators, strategy,
1017                           enable_get_next_as_optional)
1018
1019  @property
1020  def element_spec(self):
1021    # When partial batch handling is enabled, always set the batch dimension to
1022    # None, otherwise we just follow element_spec of the underlying dataset
1023    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
1025    if (self._enable_get_next_as_optional and
1026        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
1027      return nest.map_structure(
1028          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1029    return self._element_spec
1030
1031  @property
1032  def _type_spec(self):
1033    # Note that we use actual element_spec instead of the rebatched-as-dynamic
1034    # one to create DistributedIteratorSpec, to be consistent with the
1035    # underlying iterators' specs.
1036    return DistributedIteratorSpec(self._input_workers, self._element_spec,
1037                                   self._strategy,
1038                                   self._enable_get_next_as_optional,
1039                                   self._options)
1040
1041
1042class _IterableInput(DistributedDatasetInterface):
1043  """Base class for iterable inputs for distribution strategies."""
1044
1045  # pylint: disable=super-init-not-called
1046  def __init__(self, input_workers):
1047    assert isinstance(input_workers, InputWorkers)
1048    self._input_workers = input_workers
1049
1050  def __iter__(self):
1051    raise NotImplementedError("must be implemented in descendants")
1052
1053  def reduce(self, initial_state, reduce_fn):
1054    """Execute a `reduce_fn` over all the elements of the input."""
1055    iterator = iter(self)
1056    has_data, data = _get_next_as_optional(
1057        iterator, self._strategy, return_per_replica=True)
1058
1059    def cond(has_data, data, state):
1060      del data, state  # Unused.
1061      return has_data
1062
1063    def loop_body(has_data, data, state):
1064      """Executes `reduce_fn` in a loop till the dataset is empty."""
1065      del has_data  # Unused.
1066      state = reduce_fn(state, data)
1067      has_data, data = _get_next_as_optional(
1068          iterator, self._strategy, return_per_replica=True)
1069      return has_data, data, state
1070
1071    has_data, data, final_state = control_flow_ops.while_loop(
1072        cond, loop_body, [has_data, data, initial_state], parallel_iterations=1)
1073    return final_state
1074
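# Illustrative sketch (assumed `strategy` and `dist_dataset` objects): using
# `reduce` to sum every element of a distributed dataset, where each step
# reduces the per-replica values into a single scalar before accumulating.
def _example_reduce_sum(strategy, dist_dataset):
  """Illustrative helper summing all elements of a distributed dataset."""

  def reduce_fn(state, per_replica_values):
    return state + strategy.reduce(
        reduce_util.ReduceOp.SUM, per_replica_values, axis=None)

  # The initial state assumes int64 dataset elements (e.g. `Dataset.range`).
  return dist_dataset.reduce(
      constant_op.constant(0, dtype=dtypes.int64), reduce_fn)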
1075
1076class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec):
1077  """Type specification for `DistributedDataset."""
1078
1079  def __init__(self, input_workers, element_spec, strategy,
1080               enable_get_next_as_optional, options):
1081    super(DistributedDatasetSpec,
1082          self).__init__(input_workers, element_spec, strategy, options,
1083                         enable_get_next_as_optional)
1084
1085  @property
1086  def value_type(self):
1087    return DistributedDataset
1088
1089  # Overriding this method so that we can merge and reconstruct the spec object
1090  def most_specific_compatible_type(self, other):
1091    """Returns the most specific TypeSpec compatible with `self` and `other`.
1092
1093    Args:
1094      other: A `TypeSpec`.
1095
1096    Raises:
1097      ValueError: If there is no TypeSpec that is compatible with both `self`
1098        and `other`.
1099    """
1100    # pylint: disable=protected-access
1101    self.sanity_check_type(other)
1102    element_spec = nest.map_structure(
1103        lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
1104        other._element_spec)
1105    return DistributedDatasetSpec(self._input_workers, element_spec,
1106                                  self._strategy,
1107                                  self._enable_get_next_as_optional,
1108                                  self._options)
1109
1110  @property
1111  def _component_specs(self):
1112    specs = []
1113    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
1114
1115    for i, _ in enumerate(worker_device_pairs):
1116      element_spec = nest.map_structure(
1117          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
1118      specs.append(dataset_ops.DatasetSpec(element_spec))
1119    return specs
1120
1121  def _to_components(self, value):
1122    return value._cloned_datasets  # pylint: disable=protected-access
1123
1124  def _from_components(self, components):
1125    return DistributedDataset(
1126        input_workers=self._input_workers,
1127        strategy=self._strategy,
1128        components=components,
1129        element_spec=self._element_spec,
1130        enable_get_next_as_optional=self._enable_get_next_as_optional,
1131        options=self._options)
1132
1133  @staticmethod
1134  def from_value(value):
1135    # pylint: disable=protected-access
1136    return DistributedDatasetSpec(value._input_workers, value._element_spec,
1137                                  value._strategy,
1138                                  value._enable_get_next_as_optional,
1139                                  value._options)
1140
1141
1142class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor):
1143  """Distributed dataset that supports prefetching to multiple devices."""
1144
1145  def __init__(self,
1146               input_workers,
1147               strategy,
1148               dataset=None,
1149               num_replicas_in_sync=None,
1150               input_context=None,
1151               components=None,
1152               element_spec=None,
1153               enable_get_next_as_optional=None,
1154               build=True,
1155               options=None):
1156    """Distribute the dataset on all workers.
1157
1158    If `num_replicas_in_sync` is not None, we split each batch of the dataset
1159    into `num_replicas_in_sync` smaller batches, to be distributed among that
1160    worker's replicas, so that the batch size for a global step (across all
1161    workers and replicas) is as expected.
1162
1163    Args:
1164      input_workers: an `InputWorkers` object.
1165      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
1166        handle last partial batch.
1167      dataset: `tf.data.Dataset` that will be used as the input source. Either
1168        dataset or components field should be passed when constructing
        DistributedDataset. Use this when constructing DistributedDataset from a
1170        new `tf.data.Dataset`. Use components when constructing using
1171        DistributedDatasetSpec.
1172      num_replicas_in_sync: Optional integer. If this is not None, the value
1173        is used to decide how to rebatch datasets into smaller batches so that
1174        the total batch size for each step (across all workers and replicas)
1175        adds up to `dataset`'s batch size.
1176      input_context: `InputContext` for sharding. Only pass this in for between
1177        graph multi-worker cases where there is only one `input_worker`. In
1178        these cases, we will shard based on the `input_pipeline_id` and
1179        `num_input_pipelines` in the `InputContext`.
1180      components: datasets when DistributedDataset is constructed from
        DistributedDatasetSpec. Either the `dataset` or `components` field
        should be passed.
1183      element_spec: element spec for DistributedDataset when constructing from
1184        DistributedDatasetSpec. This will be used to set the element_spec for
1185        DistributedDataset and verified against element_spec from components.
1186      enable_get_next_as_optional: this is required when components is passed
1187        instead of dataset.
1188      build: whether to build underlying datasets when this object is created.
1189        This is only useful for `ParameterServerStrategy` now.
1190      options: `tf.distribute.InputOptions` used to control options on how this
1191        dataset is distributed.
1192    """
1193    super(DistributedDataset, self).__init__(input_workers=input_workers)
1194    if input_workers is None or strategy is None:
1195      raise ValueError("input_workers and strategy are required arguments")
1196    if dataset is not None and components is not None:
1197      raise ValueError("Only one of dataset or components should be present")
1198    if dataset is None and components is None:
1199      raise ValueError("At least one of dataset or components should be passed")
1200
1201    self._input_workers = input_workers
1202    self._strategy = strategy
1203    self._options = options
1204    self._input_context = input_context
1205    self._num_replicas_in_sync = num_replicas_in_sync
1206
1207    if dataset is not None:
1208      self._original_dataset = dataset
1209      self._built = False
1210      if build:
1211        self.build()
1212    else:
1213      if not build:
1214        raise ValueError(
1215            "When constructing DistributedDataset with components, build "
1216            "should not be False. This is an internal error. Please file a "
1217            "bug.")
1218      if enable_get_next_as_optional is None:
1219        raise ValueError(
1220            "When constructing DistributedDataset with components, " +
1221            "enable_get_next_as_optional should also be passed")
1222      self._cloned_datasets = components
1223      self._enable_get_next_as_optional = enable_get_next_as_optional
1224
1225      assert element_spec is not None
1226      if element_spec != _create_distributed_tensor_spec(
1227          self._strategy, self._cloned_datasets[0].element_spec):
1228        raise ValueError("Mismatched element_spec from the passed components")
1229      self._element_spec = element_spec
1230
1231      self._built = True
1232
1233  def build(self, dataset_to_replace=None):
1234    assert not self._built
1235    dataset = dataset_to_replace or self._original_dataset
1236    self._create_cloned_datasets_from_dataset(dataset, self._input_context,
1237                                              self._input_workers,
1238                                              self._strategy,
1239                                              self._num_replicas_in_sync)
1240    self._element_spec = _create_distributed_tensor_spec(
1241        self._strategy, self._cloned_datasets[0].element_spec)
1242    self._built = True
1243
1244  def _create_cloned_datasets_from_dataset(self, dataset, input_context,
1245                                           input_workers, strategy,
1246                                           num_replicas_in_sync):
1247    # We clone and shard the dataset on each worker. The current setup tries to
1248    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to
    # shard the final input such that each worker will run the entire
    # preprocessing pipeline and only receive its own shard of the dataset.
1252
1253    # Additionally, we rebatch the dataset on each worker into
1254    # `num_replicas_in_sync` smaller batches to be distributed among that
1255    # worker's replicas, so that the batch size for a global step (across all
1256    # workers and replicas) adds up to the original dataset's batch size.
1257    if num_replicas_in_sync is not None:
1258      num_workers = input_context.num_input_pipelines if input_context else len(
1259          input_workers.worker_devices)
1260      rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
1261                                         num_replicas_in_sync)
1262    else:
1263      rebatch_fn = None
1264    self._cloned_datasets = []
1265    if input_context:
1266      # Between-graph where we rely on the input_context for sharding
1267      assert input_workers.num_workers == 1
1268      if rebatch_fn is not None:
1269        dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
1270      dataset = input_ops.auto_shard_dataset(dataset,
1271                                             input_context.num_input_pipelines,
1272                                             input_context.input_pipeline_id,
1273                                             num_replicas_in_sync)
1274      self._cloned_datasets.append(dataset)
1275    else:
1276      replicated_ds = distribute.replicate(dataset,
1277                                           input_workers.worker_devices)
1278      for i, worker in enumerate(input_workers.worker_devices):
1279        with ops.device(worker):
1280          cloned_dataset = replicated_ds[worker]
1281          if rebatch_fn is not None:
1282            cloned_dataset = rebatch_fn(cloned_dataset, i)
1283          cloned_dataset = input_ops.auto_shard_dataset(
1284              cloned_dataset, len(input_workers.worker_devices), i,
1285              num_replicas_in_sync)
1286          self._cloned_datasets.append(cloned_dataset)
1287
1288    self._enable_get_next_as_optional = _enable_get_next_as_optional(
1289        strategy, dataset)
1290
1291  def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
1292    """Returns a callable that rebatches the input dataset.
1293
1294    Args:
1295      dataset: A `tf.data.Dataset` representing the dataset to be distributed.
1296      num_workers: An integer representing the number of workers to distribute
1297        `dataset` among.
1298      num_replicas_in_sync: An integer representing the number of replicas in
1299        sync across all workers.
1300    """
1301    if num_replicas_in_sync % num_workers:
1302      raise ValueError(
1303          "tf.distribute expects every worker to have the same number of "
1304          "replicas. However, encountered `num_replicas_in_sync` ({}) that "
1305          "cannot be divided by `num_workers` ({})".format(
1306              num_replicas_in_sync, num_workers))
1307
1308    num_replicas_per_worker = num_replicas_in_sync // num_workers
1309    with ops.colocate_with(dataset._variant_tensor):  # pylint: disable=protected-access
1310      batch_size = distribute.compute_batch_size(dataset)
1311
1312    def rebatch_fn(dataset, worker_index):
1313      try:
1314        # pylint: disable=protected-access
1315        def apply_rebatch():
1316          batch_sizes = distribute.batch_sizes_for_worker(
1317              batch_size, num_workers, num_replicas_per_worker, worker_index)
1318          return distribute._RebatchDataset(
1319              dataset, batch_sizes).prefetch(num_replicas_per_worker)
1320
1321        def apply_legacy_rebatch():
1322          return distribute._LegacyRebatchDataset(
1323              dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
1324
1325        with ops.colocate_with(dataset._variant_tensor):
1326          return control_flow_ops.cond(
1327              math_ops.not_equal(batch_size, -1),
1328              true_fn=apply_rebatch,
1329              false_fn=apply_legacy_rebatch)
1330      except errors.InvalidArgumentError as e:
1331        if "without encountering a batch" in str(e):
1332          six.reraise(
1333              ValueError,
1334              ValueError(
1335                  "Call the `batch` method on the input Dataset in order to be "
1336                  "able to split your input across {} replicas.\n Please see "
1337                  "the tf.distribute.Strategy guide. {}".format(
1338                      num_replicas_in_sync, e)),
1339              sys.exc_info()[2])
1340        else:
1341          raise
1342
1343    return rebatch_fn
1344
1345  def __iter__(self):
1346    if not (context.executing_eagerly() or
1347            ops.get_default_graph().building_function):
1348      raise RuntimeError("__iter__() is only supported inside of tf.function "
1349                         "or when eager execution is enabled.")
1350    if not self._built:
1351      raise ValueError("To use this dataset, you need to pass this dataset to "
1352                       "ClusterCoordinator.create_per_worker_dataset.")
1353
1354    # This is an optional flag that can be used to turn off using
1355    # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
1356    # as a stopgap solution that will allow us to roll out this change.
1357    enable_legacy_iterators = getattr(self._strategy,
1358                                      "_enable_legacy_iterators", False)
1359
1360    canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
1361                                   True)
1362
1363    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
1364                                                    self._input_workers,
1365                                                    enable_legacy_iterators,
1366                                                    self._options,
1367                                                    canonicalize_devices)
1368    if enable_legacy_iterators:
1369      iterator = DistributedIteratorV1(
1370          self._input_workers,
1371          worker_iterators,
1372          self._strategy,
1373          enable_get_next_as_optional=self._enable_get_next_as_optional)
1374    else:
1375      iterator = DistributedIterator(
1376          self._input_workers,
1377          worker_iterators,
1378          self._strategy,
1379          enable_get_next_as_optional=self._enable_get_next_as_optional,
1380          options=self._options)
1381    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
1382
1383    # When async eager is enabled, the iterator may not finish initialization
1384    # before being passed to a multi-device function, so add a sync point here
1385    # to make sure all underlying iterators are initialized.
1386    if context.executing_eagerly():
1387      context.async_wait()
1388
1389    return iterator
1390
1391  @property
1392  def element_spec(self):
1393    """The type specification of an element of this dataset."""
1394    # When partial batch handling is enabled, always set the batch dimension to
1395    # None, otherwise we just follow element_spec of the underlying dataset
1396    # (whose batch dimension may also be None). This is because with partial
1397    # batch handling we could always produce empty batches.
1398    if (self._enable_get_next_as_optional and
1399        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
1400      return nest.map_structure(
1401          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1402    return self._element_spec
1403
1404  @property
1405  def _type_spec(self):
1406    return DistributedDatasetSpec(self._input_workers, self._element_spec,
1407                                  self._strategy,
1408                                  self._enable_get_next_as_optional,
1409                                  self._options)
1410
1411
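# The rebatching performed above is what users observe when they pass an
# already-batched dataset to `Strategy.experimental_distribute_dataset`: the
# global batch is split into per-replica batches that add up to the original
# batch size. A minimal, hedged sketch of that behavior (not used by this
# module; the `_example_*` helpers in this file are illustrative only):
def _example_observe_per_replica_batches():
  import tensorflow as tf  # Local import keeps the sketch self-contained.

  strategy = tf.distribute.MirroredStrategy()
  dataset = tf.data.Dataset.range(32).batch(8)  # Global batch size of 8.
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  for batch in dist_dataset:
    # With N replicas in sync, each per-replica batch holds 8 // N elements.
    per_replica = strategy.experimental_local_results(batch)
    return [int(t.shape[0]) for t in per_replica]

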
1412class DistributedDatasetV1(DistributedDataset):
1413  """Distributed dataset that supports prefetching to multiple devices."""
1414
1415  def __init__(self,
1416               dataset,
1417               input_workers,
1418               strategy,
1419               num_replicas_in_sync=None,
1420               input_context=None,
1421               options=None):
1422    self._input_workers = input_workers
1423    super(DistributedDatasetV1, self).__init__(
1424        input_workers,
1425        strategy,
1426        dataset,
1427        num_replicas_in_sync=num_replicas_in_sync,
1428        input_context=input_context,
1429        options=options)
1430
1431  def make_one_shot_iterator(self):
1432    """Get a one-time-use iterator for DistributedDatasetV1.
1433
1434    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
1435    over the dataset or `iter` to create an iterator.
1436
1437    Returns:
1438      A DistributedIteratorV1 instance.
1439    """
1440    return self._make_one_shot_iterator()
1441
1442  def _make_one_shot_iterator(self):
1443    """Get an iterator for DistributedDatasetV1."""
1444    # Graph mode with one shot iterator is disabled because we have to call
1445    # `initialize` on the iterator which is only required if we are using a
1446    # tf.distribute strategy.
1447    if not context.executing_eagerly():
1448      raise ValueError("Cannot create a one shot iterator. Please use "
1449                       "`make_initializable_iterator()` instead.")
1450    return self._get_iterator()
1451
1452  def make_initializable_iterator(self):
1453    """Get an initializable iterator for DistributedDatasetV1.
1454
1455    Note: This API is deprecated. Please use
1456    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
1457    initializable iterator.
1458
1459    Returns:
1460      A DistributedIteratorV1 instance.
1461    """
1462    return self._make_initializable_iterator()
1463
1464  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
1465    """Get an initializable iterator for DistributedDatasetV1."""
1466    # Eager mode generates already initialized iterators. Hence we cannot create
1467    # an initializable iterator.
1468    if context.executing_eagerly():
1469      raise ValueError("Cannot create initializable iterator in Eager mode. "
1470                       "Please use `iter()` instead.")
1471    return self._get_iterator()
1472
1473  def _get_iterator(self):
1474    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
1475                                                    self._input_workers, True,
1476                                                    self._options)
1477    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
1478                                     self._strategy,
1479                                     self._enable_get_next_as_optional)
1480    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
1481
1482    # When async eager is enabled, the iterator may not finish initialization
1483    # before being passed to a multi-device function, so add a sync point here
1484    # to make sure all underlying iterators are initialized.
1485    if context.executing_eagerly():
1486      context.async_wait()
1487
1488    return iterator
1489
1490  def __iter__(self):
1491    if (ops.executing_eagerly_outside_functions() or
1492        ops.get_default_graph().building_function):
1493      return self._get_iterator()
1494
1495    raise RuntimeError("__iter__() is only supported inside of tf.function "
1496                       "or when eager execution is enabled.")
1497
1498
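# Iteration over the distributed datasets above is only supported eagerly or
# inside a `tf.function`. A hedged sketch of the common pattern: the outer
# loop runs eagerly and each distributed batch is handed to a `tf.function`
# that calls `Strategy.run` (the function and variable names here are
# illustrative, not part of this module):
def _example_iterate_with_strategy_run():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy()
  dist_dataset = strategy.experimental_distribute_dataset(
      tf.data.Dataset.range(8).batch(4))

  @tf.function
  def step(dist_batch):
    per_replica_sums = strategy.run(tf.reduce_sum, args=(dist_batch,))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_sums, axis=None)

  return [step(batch) for batch in dist_dataset]

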
1499class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec):
1500  """Type specification for `DistributedDatasetsFromFunction`."""
1501
1502  def __init__(self, input_workers, element_spec, strategy, options):
1503    super(DistributedDatasetsFromFunctionSpec,
1504          self).__init__(input_workers, element_spec, strategy, options)
1505
1506  @property
1507  def value_type(self):
1508    return DistributedDatasetsFromFunction
1509
1510  @property
1511  def _component_specs(self):
1512    specs = []
1513    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
1514
1515    for i, _ in enumerate(worker_device_pairs):
1516      element_spec = nest.map_structure(
1517          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
1518      specs.append(dataset_ops.DatasetSpec(element_spec))
1519    return specs
1520
1521  # Overriding this method so that we can merge and reconstruct the spec object
1522  def most_specific_compatible_type(self, other):
1523    """Returns the most specific TypeSpec compatible with `self` and `other`.
1524
1525    Args:
1526      other: A `TypeSpec`.
1527
1528    Raises:
1529      ValueError: If there is no TypeSpec that is compatible with both `self`
1530        and `other`.
1531    """
1532    # pylint: disable=protected-access
1533    self.sanity_check_type(other)
1534    element_spec = nest.map_structure(
1535        lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
1536        other._element_spec)  # pylint: disable=protected-access
1537    return DistributedDatasetsFromFunctionSpec(self._input_workers,
1538                                               element_spec, self._strategy,
1539                                               self._options)
1540
1541  def _to_components(self, value):
1542    return value._datasets  # pylint: disable=protected-access
1543
1544  def _from_components(self, components):
1545    return DistributedDatasetsFromFunction(
1546        input_workers=self._input_workers,
1547        strategy=self._strategy,
1548        components=components,
1549        element_spec=self._element_spec,
1550        options=self._options)
1551
1552  @staticmethod
1553  def from_value(value):
1554    # pylint: disable=protected-access
1555    return DistributedDatasetsFromFunctionSpec(
1556        input_workers=value._input_workers,
1557        element_spec=value._element_spec,
1558        strategy=value._strategy,
1559        options=value._options)
1560
1561
1562# TODO(priyag): Add other replication modes.
1563class DistributedDatasetsFromFunction(_IterableInput,
1564                                      composite_tensor.CompositeTensor):
1565  """Inputs created from dataset function."""
1566
1567  def __init__(self,
1568               input_workers,
1569               strategy,
1570               input_contexts=None,
1571               dataset_fn=None,
1572               options=None,
1573               components=None,
1574               element_spec=None,
1575               build=True):
1576    """Makes an iterable from datasets created by the given function.
1577
1578    Args:
1579      input_workers: an `InputWorkers` object.
1580      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
1581        handle last partial batch.
1582      input_contexts: A list of `InputContext` instances to be passed to call(s)
1583        to `dataset_fn`. Length and order should match worker order in
1584        `worker_device_pairs`.
1585      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
1586        Either dataset_fn or components should be passed to construct
1587        DistributedDatasetsFromFunction. Use this when constructing
1588        DistributedDataset using a function. Use components when constructing
1589        using DistributedDatasetsFromFunctionSpec.
1590      options: `tf.distribute.InputOptions` used to control options on how this
1591        dataset is distributed.
1592      components: datasets when DistributedDatasetsFromFunction is constructed
1593        from DistributedDatasetsFromFunctionSpec. Only one of dataset_fn or
1594        components should be passed.
1595      element_spec: element spec for DistributedDatasetsFromFunction when
1596        constructing from DistributedDatasetsFromFunctionSpec. This will be
1597        used to set the element_spec for DistributedDatasetsFromFunction and
1598        verified against element_spec from components.
1599      build: whether to build underlying datasets when this object is created.
1600        This is only useful for `ParameterServerStrategy` now.
1601    """
1602    super(DistributedDatasetsFromFunction, self).__init__(
1603        input_workers=input_workers)
1604    self._input_workers = input_workers
1605    self._strategy = strategy
1606    self._options = options
1607    if dataset_fn is not None and components is not None:
1608      raise ValueError("Only one of dataset_fn or components should be set")
1609    if dataset_fn is None and components is None:
1610      raise ValueError("At least one of dataset_fn or components should be set")
1611
1612    if dataset_fn is not None:
1613      if input_workers.num_workers != len(input_contexts):
1614        raise ValueError(
1615            "Number of input workers (%d) is not the same as number of "
1616            "input_contexts (%d)" %
1617            (input_workers.num_workers, len(input_contexts)))
1618      self._input_contexts = input_contexts
1619      self._dataset_fn = dataset_fn
1620      self._built = False
1621      if build:
1622        self.build()
1623    else:
1624      if element_spec is None:
1625        raise ValueError(
1626            "element_spec should also be passed when passing components")
1627      if not build:
1628        raise ValueError(
1629            "When constructing DistributedDatasetsFromFunction with "
1630            "components, build should not be False. This is an internal "
1631            "error. Please file a bug.")
1632      self._element_spec = element_spec
1633      self._datasets = components
1634      self._built = True
1635      self._enable_get_next_as_optional = _enable_get_next_as_optional(
1636          self._strategy, self._datasets[0])
1637
1638  def build(self):
1639    assert not self._built
1640    self._datasets, element_spec = (
1641        _create_datasets_from_function_with_input_context(
1642            self._input_contexts, self._input_workers, self._dataset_fn))
1643    self._element_spec = _create_distributed_tensor_spec(
1644        self._strategy, element_spec)
1645    self._enable_get_next_as_optional = _enable_get_next_as_optional(
1646        self._strategy, self._datasets[0])
1647    self._built = True
1648
1649  def __iter__(self):
1650    if not (ops.executing_eagerly_outside_functions() or
1651            ops.get_default_graph().building_function):
1652      raise RuntimeError("__iter__() is only supported inside of tf.function "
1653                         "or when eager execution is enabled.")
1654
1655    if not self._built:
1656      raise ValueError("To use this dataset, you need to pass it to "
1657                       "ClusterCoordinator.create_per_worker_dataset.")
1658
1659    # This is an optional flag that can be used to turn off using
1660    # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
1661    # as a stopgap solution that will allow us to roll out this change.
1662    enable_legacy_iterators = getattr(self._strategy,
1663                                      "_enable_legacy_iterators", False)
1664    canonicalize_devices = getattr(self._strategy, "_canonicalize_devices",
1665                                   True)
1666
1667    iterators = _create_iterators_per_worker(self._datasets,
1668                                             self._input_workers,
1669                                             enable_legacy_iterators,
1670                                             self._options,
1671                                             canonicalize_devices)
1672    if enable_legacy_iterators:
1673      iterator = DistributedIteratorV1(
1674          self._input_workers,
1675          iterators,
1676          self._strategy,
1677          enable_get_next_as_optional=self._enable_get_next_as_optional)
1678    else:
1679      iterator = DistributedIterator(
1680          input_workers=self._input_workers,
1681          iterators=iterators,
1682          strategy=self._strategy,
1683          enable_get_next_as_optional=self._enable_get_next_as_optional,
1684          options=self._options)
1685    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
1686
1687    # When async eager is enabled, the iterator may not finish initialization
1688    # before being passed to a multi-device function, so add a sync point
1689    # here to make sure all underlying iterators are initialized.
1690    if context.executing_eagerly():
1691      context.async_wait()
1692
1693    return iterator
1694
1695  @property
1696  def element_spec(self):
1697    """The type specification of an element of this dataset."""
1698    # When partial batch handling is enabled, always set the batch dimension to
1699    # None, otherwise we just follow element_spec of the underlying dataset
1700    # (whose batch dimension may also be None). This is because with partial
1701    # batch handling we could always produce empty batches.
1702    if (self._enable_get_next_as_optional and
1703        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
1704      return nest.map_structure(
1705          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1706    return self._element_spec
1707
1708  @property
1709  def _type_spec(self):
1710    return DistributedDatasetsFromFunctionSpec(self._input_workers,
1711                                               self._element_spec,
1712                                               self._strategy, self._options)
1713
1714
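# The class above backs `Strategy.distribute_datasets_from_function`, which
# calls `dataset_fn` once per input pipeline with a
# `tf.distribute.InputContext`. A hedged sketch of a `dataset_fn` that batches
# per replica and shards per worker (names are illustrative only):
def _example_distribute_datasets_from_function():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy()
  global_batch_size = 8

  def dataset_fn(input_context):
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        global_batch_size)
    dataset = tf.data.Dataset.range(64)
    # Shard by input pipeline so each worker reads a disjoint slice.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(per_replica_batch_size)

  return strategy.distribute_datasets_from_function(dataset_fn)

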
1715class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
1716  """Inputs created from dataset function."""
1717
1718  def _make_initializable_iterator(self, shared_name=None):
1719    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
1720    del shared_name  # Unused
1721    # Eager mode generates already initialized iterators. Hence we cannot create
1722    # an initializable iterator.
1723    if context.executing_eagerly():
1724      raise ValueError("Cannot create initializable iterator in Eager mode. "
1725                       "Please use `iter()` instead.")
1726    return self._get_iterator()
1727
1728  def _make_one_shot_iterator(self):
1729    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
1730    # Graph mode with one shot iterator is disabled because we have to call
1731    # `initialize` on the iterator which is only required if we are using a
1732    # tf.distribute strategy.
1733    if not context.executing_eagerly():
1734      raise ValueError("Cannot create a one shot iterator. Please use "
1735                       "`make_initializable_iterator()` instead.")
1736    return self._get_iterator()
1737
1738  def _get_iterator(self):
1739    iterators = _create_iterators_per_worker(self._datasets,
1740                                             self._input_workers, True,
1741                                             self._options)
1742    iterator = DistributedIteratorV1(self._input_workers, iterators,
1743                                     self._strategy,
1744                                     self._enable_get_next_as_optional)
1745    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
1746
1747    # When async eager is enabled, the iterator may not finish initialization
1748    # before being passed to a multi-device function, so add a sync point here
1749    # to make sure all underlying iterators are initialized.
1750    if context.executing_eagerly():
1751      context.async_wait()
1752
1753    return iterator
1754
1755  def __iter__(self):
1756    if (ops.executing_eagerly_outside_functions() or
1757        ops.get_default_graph().building_function):
1758      return self._get_iterator()
1759
1760    raise RuntimeError("__iter__() is only supported inside of tf.function "
1761                       "or when eager execution is enabled.")
1762
1763
1764# TODO(anjalisridhar): This class will be soon removed in favor of newer
1765# APIs.
1766class InputFunctionIterator(DistributedIteratorV1):
1767  """Iterator created from input function."""
1768
1769  def __init__(self, input_fn, input_workers, input_contexts, strategy):
1770    """Make an iterator for input provided via an input function.
1771
1772    Currently implements PER_WORKER mode, in which the `input_fn` is called
1773    once on each worker.
1774
1775    TODO(priyag): Add other replication modes.
1776
1777    Args:
1778      input_fn: Input function that returns a `tf.data.Dataset` object.
1779      input_workers: an `InputWorkers` object.
1780      input_contexts: A list of `InputContext` instances to be passed to call(s)
1781        to `input_fn`. Length and order should match worker order in
1782        `worker_device_pairs`.
1783      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
1784        handle last partial batch.
1785    """
1786    assert isinstance(input_workers, InputWorkers)
1787    if input_workers.num_workers != len(input_contexts):
1788      raise ValueError(
1789          "Number of input workers (%d) is not the same as number of "
1790          "input_contexts (%d)" %
1791          (input_workers.num_workers, len(input_contexts)))
1792
1793    iterators = []
1794    for i, ctx in enumerate(input_contexts):
1795      worker = input_workers.worker_devices[i]
1796      with ops.device(worker):
1797        result = input_fn(ctx)
1798        devices = input_workers.compute_devices_for_worker(i)
1799        if isinstance(result, dataset_ops.DatasetV2):
1800          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
1801        elif callable(result):
1802          iterator = _SingleWorkerCallableIterator(result, worker, devices)
1803        else:
1804          raise ValueError(
1805              "input_fn must return a tf.data.Dataset or a callable.")
1806        iterators.append(iterator)
1807
1808    super(InputFunctionIterator, self).__init__(
1809        input_workers, iterators, strategy, enable_get_next_as_optional=False)
1810    self._enable_get_next_as_optional = False
1811
1812
1813# TODO(anjalisridhar): This class will soon be removed and users should move
1814# to using DistributedIterator.
1815class DatasetIterator(DistributedIteratorV1):
1816  """Iterator created from input dataset."""
1817
1818  def __init__(self,
1819               dataset,
1820               input_workers,
1821               strategy,
1822               num_replicas_in_sync=None,
1823               input_context=None):
1824    """Make an iterator for the dataset on given devices.
1825
1826    If `num_replicas_in_sync` is not None, we split each batch of the dataset
1827    into `num_replicas_in_sync` smaller batches, to be distributed among that
1828    worker's replicas, so that the batch size for a global step (across all
1829    workers and replicas) is as expected.
1830
1831    Args:
1832      dataset: `tf.data.Dataset` that will be used as the input source.
1833      input_workers: an `InputWorkers` object.
1834      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
1835        handle last partial batch.
1836      num_replicas_in_sync: Optional integer. If this is not None, the value is
1837        used to decide how to rebatch datasets into smaller batches so that the
1838        total batch size for each step (across all workers and replicas) adds up
1839        to `dataset`'s batch size.
1840      input_context: `InputContext` for sharding. Only pass this in for between
1841        graph multi-worker cases where there is only one `input_worker`. In
1842        these cases, we will shard based on the `input_pipeline_id` and
1843        `num_input_pipelines` in the `InputContext`.
1844    """
1845    dist_dataset = DistributedDatasetV1(
1846        dataset,
1847        input_workers,
1848        strategy,
1849        num_replicas_in_sync=num_replicas_in_sync,
1850        input_context=input_context)
1851    worker_iterators = _create_iterators_per_worker(
1852        dist_dataset._cloned_datasets, input_workers, True)  # pylint: disable=protected-access
1853    super(DatasetIterator,
1854          self).__init__(input_workers, worker_iterators, strategy,
1855                         dist_dataset._enable_get_next_as_optional)  # pylint: disable=protected-access
1856    self._element_spec = dist_dataset.element_spec
1857
1858
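# `DatasetIterator` above relies on `InputContext`-based sharding in the
# between-graph multi-worker case. Users can also steer how tf.data
# auto-shards a dataset across workers via dataset options; a hedged sketch
# (the DATA policy is chosen purely for illustration):
def _example_set_auto_shard_policy():
  import tensorflow as tf

  dataset = tf.data.Dataset.range(64).batch(8)
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.DATA)
  return dataset.with_options(options)

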
1859def _dummy_tensor_fn(value_structure):
1860  """A function to create dummy tensors from `value_structure`."""
1861
1862  def create_dummy_tensor(spec):
1863    """Create a dummy tensor with possible batch dimensions set to 0."""
1864    if hasattr(spec, "_create_empty_value"):
1865      # A type spec may override the default dummy value behavior by declaring
1866      # the `_create_empty_value(self)` method. This method must return a value
1867      # compatible with the type spec with batch dimensions set to 0 or fail if
1868      # such a value does not exist. This allows a composite tensor to customize
1869      # dummy values creation as, in general, its dummy value is not composed
1870      # from dummy components (e.g. `row_splits` tensor of a RaggedTensor is
1871      # never allowed to be empty). See b/183969859 for more discussions.
1872      # TODO(b/186079336): reconsider CompositeTensor support.
1873      return spec._create_empty_value()  # pylint: disable=protected-access
1874
1875    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
1876      # Splice out the ragged dimensions.
1877      # pylint: disable=protected-access
1878      feature_shape = spec._shape[:1].concatenate(
1879          spec._shape[(1 + spec._ragged_rank):])
1880      feature_type = spec._dtype
1881      # pylint: enable=protected-access
1882    else:
1883      feature_shape = spec.shape
1884      feature_type = spec.dtype
1885    # Ideally we would set the batch dimension to 0. However, since
1886    # DistributionStrategy does not know the batch dimension, we make a best
1887    # guess. If the feature has unknown dimensions, we set them to 0. If the
1888    # feature shape is already static, we treat the first dimension as the
1889    # batch dimension and set it to 0.
1890    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
1891            if feature_shape else [])
1892    if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
1893                 feature_shape.is_fully_defined()):
1894      dims[0] = tensor_shape.Dimension(0)
1895
1896    if isinstance(spec, sparse_tensor.SparseTensorSpec):
1897      return sparse_tensor.SparseTensor(
1898          values=array_ops.zeros(0, feature_type),
1899          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
1900          dense_shape=dims)
1901
1902    # Create the dummy tensor.
1903    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
1904    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
1905      # Reinsert the ragged dimensions with size 0.
1906      # pylint: disable=protected-access
1907      row_splits = array_ops.zeros(1, spec._row_splits_dtype)
1908      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
1909          dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
1910      # pylint: enable=protected-access
1911    return dummy_tensor
1912
1913  return nest.map_structure(create_dummy_tensor, value_structure)
1914
1915
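# A small illustration (assuming a plain dense `tf.TensorSpec`) of the kind of
# zero-sized "dummy" value `_dummy_tensor_fn` produces; the real helper above
# also handles sparse, ragged and other composite tensors:
def _example_dummy_value_for_dense_spec():
  import tensorflow as tf

  spec = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
  # Unknown dimensions (here the batch dimension) are replaced with 0, which
  # yields an empty tensor that is still compatible with the spec.
  dims = [0 if dim is None else dim for dim in spec.shape.as_list()]
  return tf.zeros(dims, spec.dtype)  # Shape (0, 3).

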
1916def _recover_shape_fn(data, value_structure):
1917  """Recover the shape of `data` the same as shape of `value_structure`."""
1918
1919  flattened_data = nest.flatten(data)
1920  for i, spec in enumerate(nest.flatten(value_structure)):
1921    for target, source in zip(
1922        nest.flatten(flattened_data[i], expand_composites=True),
1923        nest.flatten(spec, expand_composites=True)):
1924      target.set_shape(source.shape)
1925    # `SparseTensor` shape is not determined by the shape of its component
1926    # tensors. Rather, its shape depends on a tensor's values.
1927    if isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape:
1928      dense_shape = spec.shape
1929      with ops.device(flattened_data[i].op.device):
1930        # For partially defined shapes, fill in missing values from tensor.
1931        if not dense_shape.is_fully_defined():
1932          dense_shape = array_ops.stack([
1933              flattened_data[i].dense_shape[j] if dim is None else dim
1934              for j, dim in enumerate(dense_shape.as_list())
1935          ])
1936        flattened_data[i] = sparse_tensor.SparseTensor(
1937            indices=flattened_data[i].indices,
1938            values=flattened_data[i].values,
1939            dense_shape=dense_shape)
1940  data = nest.pack_sequence_as(data, flattened_data)
1941  return data
1942
1943
1944class _SingleWorkerDatasetIteratorBase(object):
1945  """Iterator for a single `tf.data.Dataset`."""
1946
1947  def __init__(self, dataset, worker, devices, options=None):
1948    """Create iterator for the `dataset` to fetch data to worker's `devices`.
1949
1950    A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
1951    input to the devices on the given worker.
1952
1953    Args:
1954      dataset: A `tf.data.Dataset` instance.
1955      worker: Worker on which ops should be created.
1956      devices: Distribute data from `dataset` to these devices.
1957      options: `tf.distribute.InputOptions` controlling how data is fetched.
1958    """
1959    self._dataset = dataset
1960    self._worker = worker
1961    self._devices = devices
1962    self._element_spec = dataset.element_spec
1963    self._options = options
1964    self._make_iterator()
1965
1966  def _make_iterator(self):
1967    raise NotImplementedError("must be implemented in descendants")
1968
1969  def _format_data_list_with_options(self, data_list):
1970    """Wrap the data in a list if required.
1971
1972    The OwnedMultiDeviceIterator returns the data as a list, while the
1973    PER_REPLICA iterator (when used with prefetch disabled) returns it without
1974    the enclosing list. This is to fix the inconsistency.
1975    Args:
1976      data_list: the data returned by the underlying iterator.
1977    Returns:
1978      The data wrapped in a list if required, otherwise unchanged.
1979    """
1980    if (self._options and self._options.experimental_replication_mode ==
1981        InputReplicationMode.PER_REPLICA and
1982        not self._options.experimental_fetch_to_device):
1983      return [data_list]
1984    else:
1985      return data_list
1986
1987  def get_next(self, device, name=None):
1988    """Get next element for the given device."""
1989    del name
1990    with ops.device(self._worker):
1991      if _should_use_multi_device_iterator(self._options):
1992        return self._iterator.get_next(device)
1993      else:
1994        return self._iterator.get_next()
1995
1996  def get_next_as_list_static_shapes(self, name=None):
1997    """Get next element from the underlying iterator.
1998
1999    Runs the iterator get_next() within a device scope. Since this doesn't use
2000    get_next_as_optional(), it is considerably faster than get_next_as_list()
2001    (but can only be used when the shapes are static).
2002
2003    Args:
2004      name: not used.
2005
2006    Returns:
2007      A list consisting of the next data from each device.
2008    """
2009    del name
2010    with ops.device(self._worker):
2011      return self._format_data_list_with_options(self._iterator.get_next())
2012
2013  def get_next_as_list(self, name=None):
2014    """Get next element from underlying iterator.
2015
2016    If there is no data left, a list of dummy tensors with possible batch
2017    dimensions set to 0 will be returned. Use of get_next_as_optional() and
2018    extra logic adds overhead compared to get_next_as_list_static_shapes(), but
2019    allows us to handle non-static shapes.
2020
2021    Args:
2022      name: not used.
2023
2024    Returns:
2025      A boolean tensor indicating whether there is any data in the next
2026      element, and the real data as the next element (or a list of dummy
2027      tensors if no data is left).
2028    """
2029    del name
2030    with ops.device(self._worker):
2031      data_list = self._format_data_list_with_options(
2032          self._iterator.get_next_as_optional())
2033      result = []
2034      for i, data in enumerate(data_list):
2035        # Place the condition op in the same device as the data so the data
2036        # doesn't need to be sent back to the worker.
2037        with ops.device(self._devices[i]):
2038          # Data will be fetched in order, so we only need to check whether
2039          # the first replica has a value to see whether there is data left
2040          # for this single worker.
2041          if i == 0:
2042            worker_has_value = data.has_value()
2043
2044          # pylint: disable=unnecessary-lambda
2045          # pylint: disable=cell-var-from-loop
2046          real_data = control_flow_ops.cond(
2047              data.has_value(),
2048              lambda: data.get_value(),
2049              lambda: _dummy_tensor_fn(data.element_spec),
2050              strict=True,
2051          )
2052          # Some dimensions in `replicas` will become unknown after we
2053          # conditionally return the real tensors or the dummy tensors. Recover
2054          # the shapes from `data.element_spec`. We only need to do this in
2055          # non eager mode because we always know the runtime shape of the
2056          # tensors in eager mode.
2057          if not context.executing_eagerly():
2058            real_data = _recover_shape_fn(real_data, data.element_spec)
2059          result.append(real_data)
2060          # pylint: enable=cell-var-from-loop
2061          # pylint: enable=unnecessary-lambda
2062
2063      return worker_has_value, result
2064
2065
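# `get_next_as_list` above is built on the "optional" protocol: request an
# optional from the iterator and fall back to a dummy value when it is empty.
# A hedged, single-device sketch of that pattern using the public tf.data API,
# assuming a dense element spec whose non-batch dimensions are fully defined:
def _example_get_next_or_dummy(iterator, element_spec):
  import tensorflow as tf

  optional = iterator.get_next_as_optional()
  has_value = optional.has_value()
  value = tf.cond(
      has_value,
      lambda: optional.get_value(),
      # Dummy value with the batch dimension set to 0, mirroring what
      # `_dummy_tensor_fn` does for the dense case.
      lambda: tf.zeros([0] + element_spec.shape.as_list()[1:],
                       element_spec.dtype))
  return has_value, value

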
2066class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
2067  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""
2068
2069  __slots__ = [
2070      "_worker", "_devices", "_element_spec", "_options",
2071      "_canonicalize_devices"
2072  ]
2073
2074  def __init__(self, worker, devices, element_spec, options,
2075               canonicalize_devices=True):
2076    self._worker = worker
2077    if canonicalize_devices:
2078      self._devices = tuple(device_util.canonicalize(d) for d in devices)
2079    else:
2080      self._devices = tuple(
2081          device_util.canonicalize_without_job_and_task(d) for d in devices)
2082    self._element_spec = element_spec
2083    # `self._options` intentionally made not `None` for proper serialization.
2084    self._options = (options if options is not None else
2085                     distribute_lib.InputOptions())
2086    self._canonicalize_devices = canonicalize_devices
2087
2088  @property
2089  def value_type(self):
2090    return _SingleWorkerOwnedDatasetIterator
2091
2092  def _serialize(self):
2093    return (self._worker, self._devices, self._element_spec, self._options,
2094            self._canonicalize_devices)
2095
2096  def _get_multi_device_iterator_spec(self, specs):
2097    device_scope = device_util.canonicalize(self._worker, device_util.current())
2098    host_device = device_util.get_host_for_device(device_scope)
2099    # The source_device used while creating the iterator governs the worker
2100    # device in the iterator spec.
2101    worker = host_device
2102    specs.append(
2103        multi_device_iterator_ops.MultiDeviceIteratorSpec(
2104            self._devices, worker, element_spec=self._element_spec))
2105
2106  @property
2107  def _component_specs(self):
2108    specs = []
2109    if _should_use_multi_device_iterator(self._options):
2110      self._get_multi_device_iterator_spec(specs)
2111    else:
2112      specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
2113    return specs
2114
2115  def _to_components(self, value):
2116    return [value._iterator]  # pylint: disable=protected-access
2117
2118  def _from_components(self, components):
2119    return _SingleWorkerOwnedDatasetIterator(
2120        dataset=None,
2121        worker=self._worker,
2122        devices=self._devices,
2123        components=components,
2124        element_spec=self._element_spec,
2125        options=self._options,
2126        canonicalize_devices=self._canonicalize_devices)
2127
2128  @staticmethod
2129  def from_value(value):
2130    # pylint: disable=protected-access
2131    return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
2132                                            value._element_spec, value._options,
2133                                            value._canonicalize_devices)
2134
2135
2136class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
2137                                        composite_tensor.CompositeTensor):
2138  """Iterator for a DistributedDataset instance."""
2139
2140  def __init__(self,
2141               dataset=None,
2142               worker=None,
2143               devices=None,
2144               components=None,
2145               element_spec=None,
2146               options=None,
2147               canonicalize_devices=None):
2148    """Create iterator for the `dataset` to fetch data to worker's `devices`.
2149
2150    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
2151    given worker. The lifetime of this iterator is tied to the encompassing
2152    python object. Once the python object goes out of scope or we return from
2153    a tf.function, the underlying iterator resource is deleted.
2154
2155    Args:
2156      dataset: A `tf.data.Dataset` instance.
2157      worker: Worker on which ops should be created.
2158      devices: Distribute data from `dataset` to these devices.
2159      components: Tensor components to construct the
2160        _SingleWorkerOwnedDatasetIterator from.
2161      element_spec: A nested structure of `TypeSpec` objects that represents
2162        the type specification of elements of the iterator.
2163      options: `tf.distribute.InputOptions` used to control options on how this
2164        dataset is distributed.
2165      canonicalize_devices: Whether to canonicalize devices for workers fully
2166        or partially. If False, it will partially canonicalize devices by
2167        removing job and task.
2168    """
2169    if worker is None or devices is None:
2170      raise ValueError("Both `worker` and `devices` should be provided")
2171
2172    error_message = ("Either `dataset` or both `components` and `element_spec` "
2173                     "need to be provided.")
2174
2175    self._options = options
2176    self._canonicalize_devices = canonicalize_devices
2177    if dataset is None:
2178      if (components is None or element_spec is None):
2179        raise ValueError(error_message)
2180      self._element_spec = element_spec
2181      self._worker = worker
2182      self._devices = devices
2183      self._iterator = components[0]
2184    else:
2185      if (components is not None or element_spec is not None):
2186        raise ValueError(error_message)
2187      super(_SingleWorkerOwnedDatasetIterator,
2188            self).__init__(dataset, worker, devices, self._options)
2189
2190  def _create_owned_multi_device_iterator(self):
2191    # If the worker devices are already canonicalized, canonicalizing again
2192    # would have no impact.
2193    # For strategies running on remote workers such as PS Strategy, the device
2194    # scope will be derived from current worker, if used under init_scope().
2195    device_scope = device_util.canonicalize(self._worker,
2196                                            device_util.current())
2197    host_device = device_util.get_host_for_device(device_scope)
2198    with ops.device(device_scope):
2199      if self._options is not None:
2200        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
2201            self._dataset,
2202            self._devices,
2203            source_device=host_device,
2204            max_buffer_size=self._options
2205            .experimental_per_replica_buffer_size,
2206            prefetch_buffer_size=self._options
2207            .experimental_per_replica_buffer_size)
2208      else:
2209        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
2210            self._dataset, self._devices, source_device=host_device)
2211
2212  def _make_iterator(self):
2213    """Make appropriate iterator on the dataset."""
2214    if not self._worker:
2215      raise ValueError("Worker device must be specified when creating an "
2216                       "owned iterator.")
2217    if _should_use_multi_device_iterator(self._options):
2218      self._create_owned_multi_device_iterator()
2219    else:
2220      with ops.device(self._worker):
2221        self._iterator = iter(self._dataset)
2222
2223  @property
2224  def element_spec(self):
2225    return self._element_spec
2226
2227  @property
2228  def _type_spec(self):
2229    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
2230                                            self._element_spec, self._options,
2231                                            self._canonicalize_devices)
2232
2233  @property
2234  def output_classes(self):
2235    """Returns the class of each component of an element of this iterator.
2236
2237    The expected values are `tf.Tensor` and `tf.SparseTensor`.
2238
2239    Returns:
2240      A nested structure of Python `type` objects corresponding to each
2241      component of an element of this dataset.
2242    """
2243    return nest.map_structure(
2244        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2245        self._element_spec)
2246
2247  @property
2248  def output_shapes(self):
2249    """Returns the shape of each component of an element of this iterator.
2250
2251    Returns:
2252      A nested structure of `tf.TensorShape` objects corresponding to each
2253      component of an element of this dataset.
2254    """
2255    return nest.map_structure(
2256        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2257        self._element_spec)
2258
2259  @property
2260  def output_types(self):
2261    """Returns the type of each component of an element of this iterator.
2262
2263    Returns:
2264      A nested structure of `tf.DType` objects corresponding to each component
2265      of an element of this dataset.
2266    """
2267    return nest.map_structure(
2268        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2269        self._element_spec)
2270
2271
2272class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
2273  """Iterator for a single DistributedDatasetV1 instance."""
2274
2275  def _make_iterator(self):
2276    """Make appropriate iterator on the dataset."""
2277    with ops.device(self._worker):
2278      if self._options is not None:
2279        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
2280            self._dataset,
2281            self._devices,
2282            max_buffer_size=self._options.experimental_per_replica_buffer_size,
2283            prefetch_buffer_size=self._options
2284            .experimental_per_replica_buffer_size)
2285      else:
2286        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
2287            self._dataset,
2288            self._devices,
2289        )
2290
2291  def initialize(self):
2292    """Initialize underlying iterator.
2293
2294    In eager execution, this simply recreates the underlying iterator.
2295    In graph execution, it returns the initializer ops for the underlying
2296    iterator.
2297
2298    Returns:
2299      A list of any initializer ops that should be run.
2300    """
2301    if ops.executing_eagerly_outside_functions():
2302      self._iterator._eager_reset()  # pylint: disable=protected-access
2303      return []
2304    else:
2305      return [self._iterator.initializer]
2306
2307  @property
2308  def output_classes(self):
2309    return dataset_ops.get_legacy_output_classes(self._iterator)
2310
2311  @property
2312  def output_shapes(self):
2313    return dataset_ops.get_legacy_output_shapes(self._iterator)
2314
2315  @property
2316  def output_types(self):
2317    return dataset_ops.get_legacy_output_types(self._iterator)
2318
2319
2320class _SingleWorkerCallableIterator(object):
2321  """Iterator for a single tensor-returning callable."""
2322
2323  def __init__(self, fn, worker, devices):
2324    self._fn = fn
2325    self._worker = worker
2326    self._devices = devices
2327
2328  def get_next(self, device, name=None):
2329    """Get next element for the given device from the callable."""
2330    del device, name
2331    with ops.device(self._worker):
2332      return self._fn()
2333
2334  def get_next_as_list_static_shapes(self, name=None):
2335    """Get next element from the callable."""
2336    del name
2337    with ops.device(self._worker):
2338      data_list = [self._fn() for _ in self._devices]
2339      return data_list
2340
2341  def get_next_as_list(self, name=None):
2342    """Get next element from the callable."""
2343    del name
2344    with ops.device(self._worker):
2345      data_list = [self._fn() for _ in self._devices]
2346      return constant_op.constant(True), data_list
2347
2348  def initialize(self):
2349    # TODO(petebu) Should this throw an exception instead?
2350    return []
2351
2352
2353def _create_iterators_per_worker(worker_datasets,
2354                                 input_workers,
2355                                 enable_legacy_iterators,
2356                                 options=None,
2357                                 canonicalize_devices=False):
2358  """Create a multidevice iterator on each of the workers."""
2359  assert isinstance(input_workers, InputWorkers)
2360  assert len(worker_datasets) == len(input_workers.worker_devices)
2361  iterators = []
2362  for i, worker in enumerate(input_workers.worker_devices):
2363    with ops.device(worker):
2364      worker_devices = input_workers.compute_devices_for_worker(i)
2365      if tf2.enabled() and not enable_legacy_iterators:
2366        iterator = _SingleWorkerOwnedDatasetIterator(
2367            dataset=worker_datasets[i],
2368            worker=worker,
2369            devices=worker_devices,
2370            options=options,
2371            canonicalize_devices=canonicalize_devices)
2372      else:
2373        iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
2374                                                worker_devices, options)
2375      iterators.append(iterator)
2376  return iterators
2377
2378
2379def _create_datasets_from_function_with_input_context(input_contexts,
2380                                                      input_workers,
2381                                                      dataset_fn):
2382  """Create device datasets per worker given a dataset function."""
2383  datasets = []
2384  for i, ctx in enumerate(input_contexts):
2385    worker = input_workers.worker_devices[i]
2386    with ops.device(worker):
2387      dataset = dataset_fn(ctx)
2388      datasets.append(dataset)
2389  return datasets, dataset.element_spec
2390
2391
2392# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
2393def _get_batched_dataset(d):
2394  """Get the batched dataset from `d`."""
2395  # pylint: disable=protected-access
2396  if isinstance(d, dataset_ops.DatasetV1Adapter):
2397    d = d._dataset
2398
2399  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
2400    return d
2401  elif isinstance(d, (dataset_ops.PrefetchDataset,
2402                      dataset_ops._OptionsDataset)):
2403    return _get_batched_dataset(d._input_dataset)
2404
2405  raise ValueError(
2406      "Unable to get batched dataset from the input dataset. `batch` or "
2407      "`map_and_batch` needs to be the last operation on the dataset. "
2408      "The batch operation can be followed by a prefetch.")
2409
2410
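# `_get_batched_dataset` above can only recover the batch size when batching
# is the final transformation (optionally followed by a prefetch). A hedged
# illustration of orderings it can and cannot handle:
def _example_batch_must_be_last():
  import tensorflow as tf

  recoverable = tf.data.Dataset.range(16).map(
      lambda x: x * 2).batch(4).prefetch(1)
  # Applying `map` after `batch` hides the BatchDataset, so the batch size of
  # a dataset built like this cannot be recovered by the helper above.
  not_recoverable = tf.data.Dataset.range(16).batch(4).map(lambda x: x * 2)
  return recoverable, not_recoverable

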
2411def _get_batched_dataset_attributes(d):
2412  """Get `batch_size`, `drop_remainder` of dataset."""
2413  # pylint: disable=protected-access
2414  assert isinstance(d,
2415                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
2416  if isinstance(d, dataset_ops.BatchDataset):
2417    batch_size = d._batch_size
2418    drop_remainder = d._drop_remainder
2419  elif isinstance(d, batching._MapAndBatchDataset):
2420    batch_size = d._batch_size_t
2421    drop_remainder = d._drop_remainder_t
2422  # pylint: enable=protected-access
2423
2424  if tensor_util.is_tf_type(batch_size):
2425    batch_size = tensor_util.constant_value(batch_size)
2426
2427  if tensor_util.is_tf_type(drop_remainder):
2428    drop_remainder = tensor_util.constant_value(drop_remainder)
2429
2430  return batch_size, drop_remainder
2431
2432
2433# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
2434def _get_dataset_attributes(dataset):
2435  """Get the underlying attributes from the dataset object."""
2436  # pylint: disable=protected-access
2437
2438  # First, get batch_size and drop_remainder from the dataset. We need
2439  # to walk back the dataset creation process and find the batched version in
2440  # order to get the attributes.
2441  batched_dataset = _get_batched_dataset(dataset)
2442  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)
2443
2444  # Second, the prefetch buffer should be obtained from the original dataset.
2445  prefetch_buffer = None
2446  if isinstance(dataset, dataset_ops.PrefetchDataset):
2447    prefetch_buffer = dataset._buffer_size
2448  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
2449        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
2450    prefetch_buffer = dataset._dataset._buffer_size
2451
2452  return batch_size, drop_remainder, prefetch_buffer
2453
2454
2455def _should_use_multi_device_iterator(options):
2456  """Determine whether to use multi_device_iterator_ops."""
2457  if (options is None or
2458      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
2459      or
2460      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
2461       and options.experimental_fetch_to_device)):
2462    return True
2463  return False
2464
2465
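# A hedged sketch of how `tf.distribute.InputOptions` feeds into the decision
# above: a multi-device iterator (which prefetches to devices) is used for
# PER_WORKER replication, or for PER_REPLICA replication when
# `experimental_fetch_to_device` is enabled. The combination below opts out of
# device prefetching; it is illustrative only and assumes a strategy that
# supports PER_REPLICA replication with `distribute_datasets_from_function`:
def _example_per_replica_input_options():
  import tensorflow as tf

  strategy = tf.distribute.MirroredStrategy()
  options = tf.distribute.InputOptions(
      experimental_replication_mode=(
          tf.distribute.InputReplicationMode.PER_REPLICA),
      experimental_fetch_to_device=False)

  def dataset_fn(input_context):
    del input_context  # Unused in this minimal sketch.
    return tf.data.Dataset.range(8).batch(4)

  return strategy.distribute_datasets_from_function(dataset_fn, options)

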
2466class MultiStepContext(object):
2467  """A context object that can be used to capture things when running steps.
2468
2469  This context object is useful when running multiple steps at a time using the
2470  `experimental_run_steps_on_iterator` API. For example, it allows the user's
2471  step function to specify which outputs to emit at what frequency. Currently
2472  it supports capturing output from the last step, as well as capturing
2473  non-tensor outputs. In the future it will be augmented to support other use
2474  cases such as emitting output every N steps.
2475  """
2476
2477  def __init__(self):
2478    """Initialize an output context.
2479
2480    Returns:
2481      A context object.
2482    """
2483    self._last_step_outputs = {}
2484    self._last_step_outputs_reduce_ops = {}
2485    self._non_tensor_outputs = {}
2486
2487  @property
2488  def last_step_outputs(self):
2489    """A dictionary consisting of outputs to be captured on last step.
2490
2491    Keys in the dictionary are names of tensors to be captured, as specified
2492    when `set_last_step_output` is called.
2493    Values in the dictionary are the tensors themselves. If
2494    `set_last_step_output` was called with a `reduce_op` for this output,
2495    then the value is the reduced value.
2496
2497    Returns:
2498      A dictionary with last step outputs.
2499    """
2500    return self._last_step_outputs
2501
2502  def _set_last_step_outputs(self, outputs):
2503    """Replace the entire dictionary of last step outputs."""
2504    if not isinstance(outputs, dict):
2505      raise ValueError("Need a dictionary to set last_step_outputs.")
2506    self._last_step_outputs = outputs
2507
2508  def set_last_step_output(self, name, output, reduce_op=None):
2509    """Set `output` with `name` to be output from the last step.
2510
2511    Args:
2512      name: String, name to identify the output. Doesn't need to match tensor
2513        name.
2514      output: The tensors that should be output with `name`. See below for
2515        actual types supported.
2516      reduce_op: Reduction method to use to reduce outputs from multiple
2517        replicas. Required if `set_last_step_output` is called in a replica
2518        context. Optional in cross_replica_context.
2519        When present, the outputs from all the replicas are reduced using the
2520        current distribution strategy's `reduce` method. Hence, the type of
2521        `output` must be what's supported by the corresponding `reduce` method.
2522        For example, if using MirroredStrategy and reduction is set, output
2523        must be a `PerReplica` value.
2524        The reduce method is also recorded in the
2525        `_last_step_outputs_reduce_ops` dictionary so that the outputs can
2526        later be interpreted as already reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the
        # same context object, so it's more robust to set it only once (even
        # if all the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non-tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non-tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non-tensor outputs, we simply return all the
        # values in a list, as reduction doesn't make sense on non-tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
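
  # Example (illustrative, not part of the original code): inside a step
  # function running in a replica context, a non-tensor value can be captured
  # without any reduction; the stored value becomes the tuple of per-replica
  # results:
  #
  #   ctx.set_non_tensor_output("summary_ops", summary_ops)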


def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
      shape should be None if you have partial batches.

  Returns:
    A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
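
  For example (an informal sketch, assuming a strategy with two replicas):

  ```python
  spec = tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32)
  # The result would be values.PerReplicaSpec(spec, spec): one component
  # spec per replica. For a single-device, single-worker strategy the
  # input spec is returned unchanged.
  ```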
  """
  num_replicas = len(strategy.extended.worker_devices)

  # For a one-device strategy that is not MultiWorkerMirroredStrategy, return
  # the tensor_spec as is, since we don't wrap the output with PerReplica in
  # this case.
  # TODO(b/166464552): remove after we always wrap for all strategies.
  if not _always_wrap(strategy):
    return tensor_spec

  # For other cases we assume the input to tf.function is a per-replica type.
  def _get_value_per_replica(tensor_spec_per_input):
    value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
    return values.PerReplicaSpec(*value_specs)

  return nest.map_structure(_get_value_per_replica, tensor_spec)


def _replace_per_replica_spec(spec, i):
  """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
  if isinstance(spec, values.PerReplicaSpec):
    return spec._value_specs[i]  # pylint: disable=protected-access
  else:
    return spec


def _enable_get_next_as_optional(strategy, dataset):
  """Returns whether to enable partial batch handling."""
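  # Roughly (an added explanatory note, not an original comment): assuming the
  # partial batch handling flag is enabled on the strategy, this returns True
  # when the element spec has non-static shapes or the strategy runs in
  # multi-worker mode, and False for a single-worker strategy with statically
  # shaped elements or for an infinite dataset created eagerly.
  #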
  # TODO(b/133073708): we currently need a flag to control the usage because
  # there is a performance difference between get_next() and
  # get_next_as_optional(), and we only enable get_next_as_optional when the
  # output shapes are not static.
  #
  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
  # when the user passes an input_fn instead of a dataset.
  if not getattr(
      strategy.extended, "enable_partial_batch_handling",
      getattr(strategy.extended, "experimental_enable_get_next_as_optional",
              False)):
    return False

  if context.executing_eagerly():
    # If the dataset is infinite, we don't need to enable last partial batch
    # support. Currently this logic only applies when the distributed dataset
    # is created in eager mode, as we need to evaluate the dataset
    # cardinality.
    with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
      if dataset.cardinality().numpy() == cardinality.INFINITE:
        return False

  return not _is_statically_shaped(
      dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


def _create_per_replica(value_list, strategy):
  """Creates a PerReplica.

  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic for multi-client
  setups when different clients retrace at different times, since retracing
  changes the collective keys in the tf.function and causes mismatches among
  clients.

  For single-client strategies, this simply calls distribute_utils.regroup().

  Args:
    value_list: a list of values, one for each replica.
    strategy: the `tf.distribute.Strategy`.

  Returns:
    a structure of PerReplica.

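  For example (illustrative only, with hypothetical per-device batches):

  ```python
  # With a two-replica strategy, the two per-device batches are wrapped into
  # a single PerReplica value.
  per_replica = _create_per_replica([batch_gpu0, batch_gpu1], strategy)
  ```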
  """
  # TODO(b/166464552): always wrap for all one device strategies as well.
  always_wrap = _always_wrap(strategy)
  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
  return per_replicas


def _always_wrap(strategy):
  """Returns whether to always wrap the values in a DistributedValues."""
  return strategy.extended._in_multi_worker_mode() or len(  # pylint: disable=protected-access
      strategy.extended.worker_devices) > 1


def _rebatch_as_dynamic(per_replica_spec):
  """Rebatch the spec to have a dynamic batch dimension."""
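  # Rough illustration (an added note, not an original comment): a component
  # spec such as TensorSpec([32, 10]) is unbatched to TensorSpec([10]) and
  # re-batched to TensorSpec([None, 10]), making the leading (batch) dimension
  # dynamic. Specs that cannot be unbatched are left unchanged.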
  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec

  # pylint: disable=protected-access
  def _rebatch(spec):
    # Rebatch if possible.
    try:
      return spec._unbatch()._batch(None)
    except ValueError:
      pass
    return spec

  return values.PerReplicaSpec(
      *nest.map_structure(_rebatch, per_replica_spec._value_specs))
  # pylint: enable=protected-access
