# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import six

from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated


def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            split_batch_by=None,
                            input_context=None):
  """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.

  This is a common function that is used by all strategies to return the right
  tf.data.Dataset wrapped instance depending on the `dataset` argument type.

  Args:
    dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.

  Returns:
    A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
  """
  if isinstance(dataset, dataset_ops.DatasetV1):
    return DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)
  else:
    return DistributedDataset(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)
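# A minimal usage sketch (not taken from this file): `get_distributed_dataset`
# is an internal helper called by strategy implementations, so the strategy,
# devices, and dataset below are illustrative assumptions for a single-worker,
# two-GPU setup rather than a documented public API.
#
#   strategy = tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])
#   input_workers = InputWorkers([("/cpu:0", ("/gpu:0", "/gpu:1"))])
#   dataset = tf.data.Dataset.range(8).batch(4)
#   dist_dataset = get_distributed_dataset(dataset, input_workers, strategy)
#   for per_replica_batch in dist_dataset:
#     ...  # `per_replica_batch` has one component per GPU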


def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy):
  """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.

  This is a common function that is used by all strategies to return the right
  tf.data.Dataset wrapped instance depending on if we are in graph or eager
  mode.

  Args:
    dataset_fn: a function that returns a tf.data.DatasetV1 or tf.data.DatasetV2
        instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `dataset_fn`. Length and order should match worker order in
        `worker_device_pairs`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.

  Returns:
    A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
  """
  if ops.executing_eagerly_outside_functions():
    return DistributedDatasetsFromFunction(
        dataset_fn,
        input_workers,
        input_contexts,
        strategy)
  else:
    return DistributedDatasetsFromFunctionV1(
        dataset_fn,
        input_workers,
        input_contexts,
        strategy)


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  def __init__(self, worker_device_pairs):
    """Initialize an `InputWorkers` object.

    Args:
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
    """
    self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
    self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                              for _, f in worker_device_pairs)

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
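# Example of the expected `worker_device_pairs` structure (device strings are
# illustrative; real strategies derive this mapping from their own device
# configuration rather than hand-writing it like this):
#
#   input_workers = InputWorkers(
#       [("/job:worker/task:0/device:CPU:0",
#         ("/job:worker/task:0/device:GPU:0",
#          "/job:worker/task:0/device:GPU:1"))])
#   input_workers.num_workers                    # -> 1
#   input_workers.compute_devices_for_worker(0)  # -> the two GPU devices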


def _get_next_as_optional(iterator, strategy, name=None):
  """Returns an empty dataset indicator and the next input from the iterator."""
  replicas = []
  worker_has_values = []
  worker_devices = []
  for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
    if name is not None:
      d = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, d.job, d.task)
    else:
      new_name = None

    with ops.device(worker):
      worker_has_value, next_element = (
          iterator._iterators[i].get_next_as_list(new_name))  # pylint: disable=protected-access
      # Collective all-reduce requires explicit devices for inputs.
      with ops.device("/cpu:0"):
        # Converting to integers for all-reduce.
        worker_has_value = math_ops.cast(worker_has_value, dtypes.int32)
        worker_devices.append(worker_has_value.device)
        worker_has_values.append(worker_has_value)
      # Make `replicas` a flat list of values across all replicas.
      replicas.append(next_element)

  # Run an all-reduce to see whether any worker has values.
  # TODO(b/131423105): we should be able to short-cut the all-reduce in some
  # cases.
  if getattr(strategy.extended, "_support_per_replica_values", True):
    # Slight hack: `reduce` expects a `PerReplica`, so we pass it one, even
    # though it doesn't actually have a value per replica.
    worker_has_values = values.PerReplica(worker_has_values)
    global_has_value = strategy.reduce(
        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
  else:
    assert len(worker_has_values) == 1
    global_has_value = worker_has_values[0]
  global_has_value = array_ops.reshape(
      math_ops.cast(global_has_value, dtypes.bool), [])
  return global_has_value, replicas


class DistributedIterator(object):
  """Common implementation for all input iterators."""

  def __init__(self, input_workers, iterators, strategy):
    static_shape = True
    for iterator in iterators:
      if not isinstance(iterator, _SingleWorkerDatasetIterator):
        continue
      flattened_shapes = nest.flatten(iterator.output_shapes)
      for output_shape in flattened_shapes:
        if not output_shape.is_fully_defined():
          static_shape = False
          break

    # TODO(b/133073708): we currently need a flag to control the usage because
    # there is a performance difference between get_next() and
    # get_next_as_optional(). And we only enable get_next_as_optional when the
    # output shapes are not static.
    #
    # TODO(yuefengz): Currently `experimental_enable_get_next_as_optional` is
    # always set to False in CollectiveAllReduceStrategy. We want a way to
    # distinguish the multi-worker and single-worker cases in graph mode, so
    # that we can enable this behavior in the single-worker case.
    #
    # TODO(rxsang): We want to always enable the get_next_as_optional behavior
    # when the user passes an input_fn instead of a dataset.
    if getattr(
        strategy.extended, "experimental_enable_get_next_as_optional", False):
      self._enable_get_next_as_optional = not static_shape
    else:
      self._enable_get_next_as_optional = False

    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    if not self._enable_get_next_as_optional:
      replicas = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        if name is not None:
          d = tf_device.DeviceSpec.from_string(worker)
          new_name = "%s_%s_%d" % (name, d.job, d.task)
        else:
          new_name = None
        with ops.device(worker):
          # Make `replicas` a flat list of values across all replicas.
          replicas.extend(
              self._iterators[i].get_next_as_list_static_shapes(new_name))
      return values.regroup(replicas)

    out_of_range_replicas = []
    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # This is only called when there is no data left, so calling get_next()
      # will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(
                global_has_value,
                lambda: replicas[i][j],
                lambda: out_of_range_fn(i, device),
                strict=True,
            )
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    # Some dimensions in `replicas` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas` because it calls get_next()
    # inside.
    flattened_replicas = nest.flatten(replicas)
    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
      for target, source in zip(
          nest.flatten(flattened_replicas[i], expand_composites=True),
          nest.flatten(replica_data, expand_composites=True)):
        target.set_shape(source.get_shape())
      # `SparseTensor` shape is not determined by the shape of its component
      # tensors. Rather, its shape depends on the values of its `dense_shape`
      # tensor.
      if sparse_tensor.is_sparse(replica_data) and replica_data.get_shape():
        dense_shape = replica_data.get_shape()
        with ops.device(flattened_replicas[i].op.device):
          # For partially defined shapes, fill in missing values from tensor.
          if not dense_shape.is_fully_defined():
            dense_shape = array_ops.stack([
                flattened_replicas[i].dense_shape[j] if dim is None else dim
                for j, dim in enumerate(dense_shape.as_list())
            ])
          flattened_replicas[i] = sparse_tensor.SparseTensor(
              indices=flattened_replicas[i].indices,
              values=flattened_replicas[i].values,
              dense_shape=dense_shape)
    replicas = nest.pack_sequence_as(replicas, flattened_replicas)

    return values.regroup(replicas)

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec


class DistributedIteratorV1(DistributedIterator):
  """Input Iterator for tf.data.DatasetV1."""

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    return super(DistributedIteratorV1, self)._initializer

  @property
  def initializer(self):
    """Returns a list of ops that initialize the iterator."""
    return self.initialize()

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None


class _IterableInput(object):
  """Base class for iterable inputs for distribution strategies."""

  def __init__(self, input_workers):
    assert isinstance(input_workers, InputWorkers)
    self._input_workers = input_workers

  def __iter__(self):
    raise NotImplementedError("must be implemented in descendants")

  def reduce(self, initial_state, reduce_fn):
    """Execute a `reduce_fn` over all the elements of the input."""
    iterator = iter(self)
    has_data, data = _get_next_as_optional(iterator, self._strategy)

    def cond(has_data, data, state):
      del data, state  # Unused.
      return has_data

    def loop_body(has_data, data, state):
      """Executes `reduce_fn` in a loop till the dataset is empty."""
      del has_data  # Unused.
      # `data` is a list of lists here, where each inner list corresponds to
      # one worker.
      # TODO(b/130570614): Add support for the multiworker and TPU pods use
      # case.
      if self._input_workers.num_workers == 1:
        data = data[0]
      else:
        raise ValueError("Dataset iteration within a tf.function is"
                         " not supported for multiple workers.")
      state = reduce_fn(state, values.regroup(data))
      has_data, data = _get_next_as_optional(iterator, self._strategy)
      return has_data, data, state

    has_data, data, final_state = control_flow_ops.while_loop(
        cond, loop_body, [has_data, data, initial_state], parallel_iterations=1)
    return final_state
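# Rough usage sketch for `reduce` (names are illustrative, not part of this
# file): with a single worker and a single replica the regrouped `batch` below
# is an ordinary tensor, so a plain sum works; with multiple replicas it would
# be a per-replica value and would need a strategy-level reduction instead.
# This is typically wrapped in a `tf.function`.
#
#   dist_dataset = strategy.experimental_distribute_dataset(
#       tf.data.Dataset.range(10).batch(2))
#   total = dist_dataset.reduce(
#       tf.constant(0, dtype=tf.int64),
#       lambda state, batch: state + tf.reduce_sum(batch))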


class DistributedDataset(_IterableInput):
  """Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to
    # shard the final input such that each worker will run the entire
    # preprocessing pipeline and only receive its own shard of the dataset.
    if split_batch_by:
      try:
        # pylint: disable=protected-access
        with ops.colocate_with(dataset._variant_tensor):
          dataset = distribute._RebatchDataset(dataset, split_batch_by)
          # Add a prefetch to pipeline rebatching for performance.
          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
          # leverage static graph rewrites to insert _RebatchDataset before
          # the final `prefetch` if it exists.
          dataset = dataset.prefetch(split_batch_by)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to be "
                  "able to split your input across {} replicas.\n Please see "
                  "the tf.distribute.Strategy guide. {}".format(
                      split_batch_by, e)),
              sys.exc_info()[2])
        else:
          raise

    # TODO(b/138745411): Remove once stateful transformations are supported.
    options = dataset_ops.Options()
    options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
    dataset = dataset.with_options(options)

    self._cloned_datasets = []
    if input_context:
      # Between-graph case, where we rely on the input_context for sharding.
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          cloned_dataset = cloned_dataset.with_options(dataset.options())
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    self._strategy = strategy
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         dataset.element_spec)  # pylint: disable=protected-access

  def __iter__(self):
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers)
    iterator = DistributedIterator(self._input_workers, worker_iterators,
                                   self._strategy)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    return self._element_spec
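
# Typical way a `DistributedDataset` is obtained and consumed (sketch only;
# `step_fn` is a hypothetical user function, and in practice the dataset is
# built via the strategy rather than by constructing this class directly; the
# API names reflect this code's vintage, where `strategy.experimental_run_v2`
# is the per-replica execution entry point):
#
#   dist_dataset = strategy.experimental_distribute_dataset(dataset)
#   for per_replica_batch in dist_dataset:   # uses __iter__ above
#     strategy.experimental_run_v2(step_fn, args=(per_replica_batch,))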


class DistributedDatasetV1(DistributedDataset):
  """Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)

  def make_one_shot_iterator(self):
    """Get a one time use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers)
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return iterator


# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
  """Inputs created from dataset function."""

  def __init__(self, dataset_fn, input_workers, input_contexts, strategy):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `dataset_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)

    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    self._dataset_fn = dataset_fn
    self._input_workers = input_workers
    self._input_contexts = input_contexts
    self._strategy = strategy
    self._element_spec = None

  def __iter__(self):
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    iterators, element_spec = _create_iterators_per_worker_with_input_context(
        self._input_contexts, self._input_workers, self._dataset_fn)
    iterator = DistributedIterator(self._input_workers, iterators,
                                   self._strategy)
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         element_spec)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    if self._element_spec is None:
      raise ValueError("You must create an iterator before calling "
                       "`element_spec` on the distributed dataset or iterator. "
                       "This is because the dataset function is not called "
                       "before an iterator is created.")

    return self._element_spec


class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
  """Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators, element_spec = _create_iterators_per_worker_with_input_context(
        self._input_contexts, self._input_workers, self._dataset_fn)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy)
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         element_spec)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return iterator


# TODO(anjalisridhar): This class will soon be removed in favor of newer APIs.
class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
    """
    assert isinstance(input_workers, InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, dataset_ops.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(input_workers, iterators,
                                                strategy)


# TODO(anjalisridhar): This class will soon be removed and users should move
# to using DistributedIterator.
class DatasetIterator(DistributedIteratorV1):
  """Iterator created from input dataset."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    dist_dataset = DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)  # pylint: disable=protected-access
    super(DatasetIterator, self).__init__(
        input_workers,
        worker_iterators,  # pylint: disable=protected-access
        strategy)
    self._element_spec = dist_dataset.element_spec


def _dummy_tensor_fn(value_structure):
  """A function to create dummy tensors from `value_structure`."""

  def create_dummy_tensor(type_spec):
    """Create a dummy tensor with possible batch dimensions set to 0."""
    if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
      # Splice out the ragged dimensions.
      # pylint: disable=protected-access
      feature_shape = type_spec._shape[:1].concatenate(
          type_spec._shape[(1 + type_spec._ragged_rank):])
      feature_type = type_spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape = type_spec.shape
      feature_type = type_spec.dtype
    # Ideally we should set the batch dimension to 0. However, in
    # DistributionStrategy we don't know which dimension is the batch
    # dimension, so we make a best-effort guess. If the feature has unknown
    # dimensions, we set them to 0. If the feature shape is already static, we
    # guess that the first dimension is the batch dimension and set it to 0.
    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
            if feature_shape else [])
    if dims and (isinstance(type_spec, ragged_tensor.RaggedTensorSpec) or
                 feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(type_spec, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
          dense_shape=dims)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
      # Reinsert the ragged dimensions with size 0.
      # pylint: disable=protected-access
      row_splits = array_ops.zeros(1, type_spec._row_splits_dtype)
      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
          dummy_tensor, (row_splits,) * type_spec._ragged_rank, validate=False)
      # pylint: enable=protected-access
    return dummy_tensor

  return nest.map_structure(create_dummy_tensor, value_structure)
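# Illustrative behaviour (assumed shapes, not exhaustive): for a value
# structure of `tf.TensorSpec([None, 3], tf.float32)` the unknown batch
# dimension becomes 0 and the dummy is `tf.zeros([0, 3])`; for a fully defined
# `tf.TensorSpec([4, 3], tf.float32)` the first dimension is assumed to be the
# batch dimension, again yielding `tf.zeros([0, 3])`.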


class _SingleWorkerDatasetIterator(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    `MultiDeviceIterator` is used to prefetch input to the devices on the
    given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._make_iterator()

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
    with ops.device(self._worker):
      return self._iterator.get_next(device)

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the underlying iterator.

    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than get_next_as_list()
    (but can only be used when the shapes are static).

    Args:
      name: not used.

    Returns:
      A list consisting of the next data from each device.
    """
    del name
    with ops.device(self._worker):
      return self._iterator.get_next()

  def get_next_as_list(self, name=None):
    """Get next element from the underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned. Use of get_next_as_optional() and
    extra logic adds overhead compared to get_next_as_list_static_shapes(), but
    allows us to handle non-static shapes.

    Args:
      name: not used.

    Returns:
      A boolean tensor indicating whether there is any data in the next
      element, and the real data as the next element (or a list of dummy
      tensors if no data is left).
    """
    del name
    with ops.device(self._worker):
      data_list = self._iterator.get_next_as_optional()
      result = []
      for i, data in enumerate(data_list):
        # Place the condition op on the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[i]):
          # MultiDeviceIterator will fetch data in order, so we only need to
          # check whether the first replica has a value to see whether there
          # is data left for this single worker.
          if i == 0:
            worker_has_value = data.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          real_data = control_flow_ops.cond(
              data.has_value(),
              lambda: data.get_value(),
              lambda: _dummy_tensor_fn(data.value_structure),
              strict=True,
          )
          result.append(real_data)
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, result

  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return data_list

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return constant_op.constant(True), data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []


def _create_iterators_per_worker(worker_datasets, input_workers):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)

  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                              worker_devices)
      iterators.append(iterator)
  return iterators


def _create_iterators_per_worker_with_input_context(input_contexts,
                                                    input_workers,
                                                    dataset_fn):
  """Create a multidevice iterator per worker, given a dataset function."""
  iterators = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      # TODO(b/138745411): Remove once stateful transformations are supported.
      options = dataset_ops.Options()
      options.experimental_distribute._make_stateless = True  # pylint: disable=protected-access
      dataset = dataset.with_options(options)
      devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
      iterators.append(iterator)
  return iterators, dataset.element_spec


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
  # pylint: disable=protected-access
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` needs to be the last operation on the dataset. "
      "The batch operation can be followed by a prefetch.")
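# Example of the pipeline shapes this helper accepts (illustrative):
#
#   batched = _get_batched_dataset(
#       tf.data.Dataset.range(10).batch(2).prefetch(1))
#   # `batched` is the underlying BatchDataset; the prefetch is walked past.
#
# A pipeline whose last transformations are not `batch`/`map_and_batch`
# (optionally followed by `prefetch`) raises the ValueError above.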


def _get_batched_dataset_attributes(d):
  """Get `batch_size`, `drop_remainder` of dataset."""
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size = d._batch_size
    drop_remainder = d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size = d._batch_size_t
    drop_remainder = d._drop_remainder_t
  # pylint: enable=protected-access

  if tensor_util.is_tensor(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tensor(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  return batch_size, drop_remainder


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object."""
  # pylint: disable=protected-access

  # First, get batch_size and drop_remainder from the dataset. We need
  # to walk back the dataset creation process and find the batched version in
  # order to get the attributes.
  batched_dataset = _get_batched_dataset(dataset)
  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)

  # Second, the prefetch buffer size should be obtained from the original
  # dataset.
  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size

  return batch_size, drop_remainder, prefetch_buffer


class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's
  step function to specify which outputs to emit at what frequency. Currently
  it supports capturing output from the last step, as well as capturing non
  tensor outputs. In the future it will be augmented to support other use cases
  such as output every N steps.
  """

  def __init__(self):
    """Initialize an output context.

    Returns:
      A context object.
    """
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary, so that it is later known
        whether the outputs were already reduced.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply return all the values
        # in a list as reduction doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
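
# Rough usage sketch (names such as `step_fn` and `compute_loss` are
# hypothetical, and the exact `experimental_run_steps_on_iterator` signature
# may differ across versions): the step function runs the per-replica
# computation, records the loss on the context, and the caller reads it back
# after the multi-step loop:
#
#   def step_fn(ctx, inputs):
#     per_replica_loss = strategy.experimental_run_v2(
#         compute_loss, args=(inputs,))
#     ctx.set_last_step_output("loss", per_replica_loss,
#                              reduce_op=tf.distribute.ReduceOp.MEAN)
#     return per_replica_loss
#
#   ctx = strategy.extended.experimental_run_steps_on_iterator(
#       step_fn, iterator, iterations=10)
#   final_loss = ctx.last_step_outputs["loss"]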


def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
      shape should be None if you have partial batches.

  Returns:
    A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
  """
  num_replicas = len(strategy.extended.worker_devices)

  # If the number of devices used in the strategy is just 1 then we return
  # the tensor_spec as is.
  if num_replicas == 1:
    return tensor_spec

  # If the number of devices is greater than 1 then we assume the input to
  # tf.function is a per replica type.
  def _get_value_per_replica(tensor_spec_per_input):
    value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
    return values.PerReplicaSpec(*value_specs)

  return nest.map_structure(_get_value_per_replica, tensor_spec)
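
# Illustrative behaviour (assuming a strategy with two worker devices): a
# `tf.TensorSpec([None, 3], tf.float32)` input is turned into a
# `PerReplicaSpec` holding two copies of that spec, while with a single device
# the spec is returned unchanged.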