# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops


class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset."""

  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_spec):
    self._element_spec = element_spec

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=structure.get_flat_tensor_types(self._element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self._element_spec)))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=structure.get_flat_tensor_types(self._element_spec),
          output_shapes=structure.get_flat_tensor_shapes(self._element_spec))

    next_func_concrete = _next_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=structure.get_flat_tensor_types(self._element_spec),
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg is incarnation_id:
        self._incarnation_id_index = i

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    return self._element_spec

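# For orientation: the wiring above follows the generic GeneratorDataset
# contract. A rough, illustrative sketch (function names here are made up;
# the real implementations are the `gen_dataset_ops` calls in
# `_PerDeviceGenerator`):
#
#   def init_func():             # Runs once; returns the iterator state.
#     return string_handle       # Here: the multi-device iterator handle.
#
#   def next_func(state):        # Produces the next element for this shard.
#     return get_next_from_shard(state, shard_num, incarnation_id)
#
#   def finalize_func(state):    # Cleanup hook; its return value is unused.
#     return 0
#
# Each function is additionally wrapped in a `remote_call` so that it executes
# on `source_device`, where the underlying multi-device iterator lives.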

class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.

  Re-uses the functions from the provided per_device_dataset and just switches
  out the function argument corresponding to the incarnation_id.
  """

  def __init__(self, per_device_dataset, incarnation_id):
    # pylint: disable=protected-access
    self._element_spec = per_device_dataset.element_spec
    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    self._next_captured_args = per_device_dataset._next_captured_args
    # The captured arguments to the next_func are string_handle, incarnation_id.
    # We update the incarnation id to the new one.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    return self._element_spec

def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size,
                           experimental_slack):
  """Uses `prototype_ds` to build a per-device dataset for `incarnation_id`."""
  ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id)
  if prefetch_buffer_size > 0:
    if experimental_slack:
      ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size, slack_period=1)
    else:
      ds = ds.prefetch(prefetch_buffer_size)
  return ds


class MultiDeviceIterator(object):
  """An iterator over multiple devices."""

  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 0, then we setup a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
    """
    options = options_lib.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    self._dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
    self._experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **self._dataset._flat_structure))  # pylint: disable=protected-access
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                 self._incarnation_id,
                                 self._source_device_tensor,
                                 self._dataset.element_spec)
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)

  def _create_device_dataset(self, i):
    """Uses _prototype_device_datasets[i] to build a dataset for the device."""
    ds = self._prototype_device_datasets[i]
    ds = _ReincarnatedPerDeviceGenerator(ds, self._incarnation_id)
    if self._prefetch_buffer_size > 0:
      if self._experimental_slack:
        ds = dataset_ops.PrefetchDataset(
            ds, self._prefetch_buffer_size, slack_period=1)
      else:
        ds = ds.prefetch(self._prefetch_buffer_size)
    return ds

  def get_next(self, device=None):
    """Returns the next element given a `device`, else returns all in a list."""
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def initializer(self):
    if context.executing_eagerly():
      return control_flow_ops.no_op()
    return self._initializer

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode."""
    if not ops.executing_eagerly_outside_functions():
      raise ValueError("Eager reset is only supported in eager mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def element_spec(self):
    return self._dataset.element_spec

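# A minimal usage sketch for `MultiDeviceIterator` (illustrative only; assumes
# two visible GPUs and the public `tf` namespace):
#
#   dataset = tf.data.Dataset.range(100)
#   mdi = MultiDeviceIterator(dataset, devices=["/gpu:0", "/gpu:1"])
#   elem_gpu0, elem_gpu1 = mdi.get_next()
#
# In graph mode, run `mdi.initializer` before the first `get_next()`; in eager
# mode the initializer is a no-op (see the `initializer` property above).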

class MultiDeviceIteratorResourceDeleter(object):
  """An object which cleans up a Multi Device Iterator resource.

  An alternative to defining a __del__ method on an object. Even if the parent
  object is part of a reference cycle, the cycle will be collectible.
  """

  __slots__ = [
      "_deleter", "_multi_device_iterator", "_iterators", "_device",
      "_eager_mode"
  ]

  def __init__(self, multi_device_iterator, iterators, device, deleter):
    self._deleter = deleter
    self._multi_device_iterator = multi_device_iterator
    self._iterators = iterators
    self._device = device
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    with ops.device(self._device):
      # Make sure the resource is deleted in the same mode as it was created in.
      # We pass in the iterator handles as inputs to the op to make sure that
      # this op runs after all the iterators are deleted.
      if self._eager_mode:
        with context.eager_mode():
          gen_dataset_ops.delete_multi_device_iterator(
              multi_device_iterator=self._multi_device_iterator,
              iterators=self._iterators,
              deleter=self._deleter)
      else:
        with context.graph_mode():
          gen_dataset_ops.delete_multi_device_iterator(
              multi_device_iterator=self._multi_device_iterator,
              iterators=self._iterators,
              deleter=self._deleter)

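# The class above is an instance of the "companion deleter" pattern: rather
# than giving the owning object a __del__ method (which can prevent reference
# cycles involving the owner from being collected), the cleanup logic lives on
# a tiny helper that holds no back-references. A minimal sketch of the pattern
# (hypothetical names, not part of this module):
#
#   class Owner(object):
#
#     def __init__(self):
#       self._resource = allocate_resource()
#       # The helper dies when Owner dies; its __del__ frees the resource.
#       self._resource_deleter = ResourceDeleter(self._resource)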

class MultiDeviceIteratorSpec(type_spec.TypeSpec):
  """Type specification for `OwnedMultiDeviceIterator`."""

  __slots__ = ["_devices", "_source_device", "_element_spec"]

  def __init__(self, devices, source_device, element_spec):
    self._devices = devices
    self._source_device = source_device
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedMultiDeviceIterator

  def _serialize(self):
    return (tuple(self._devices), self._source_device, self._element_spec)

  @property
  def _component_specs(self):
    specs = [
        tensor_spec.TensorSpec([], dtypes.resource),
        tensor_spec.TensorSpec([], dtypes.variant)
    ]
    for _ in range(len(self._devices)):
      specs.append(iterator_ops.IteratorSpec(self._element_spec))
    return specs

  def _to_components(self, value):
    # pylint: disable=protected-access
    c = [value._multi_device_iterator_resource, value._deleter]
    c.extend(value._device_iterators)
    return c

  def _from_components(self, components):
    return OwnedMultiDeviceIterator(
        dataset=None,
        devices=self._devices,
        source_device=self._source_device,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return MultiDeviceIteratorSpec(
        value._devices,
        value._source_device,
        value.element_spec)

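# For reference, the flat component list produced by `_to_components` above has
# the following layout for an iterator over N devices, matching the order of
# `_component_specs`:
#
#   components[0]      -> multi-device iterator resource (scalar, dtypes.resource)
#   components[1]      -> deleter tensor (scalar, dtypes.variant)
#   components[2:2+N]  -> one owned per-device iterator per device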

class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
  """An iterator over multiple devices.

  The multi-device iterator resource created through `OwnedMultiDeviceIterator`
  is owned by the Python object and the life time of the underlying resource is
  tied to the life time of the `OwnedMultiDeviceIterator` object. This makes
  `OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of
  tf.functions.
  """

  def __init__(self,
               dataset=None,
               devices=None,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0",
               components=None,
               element_spec=None):
    """Constructs an owned MultiDeviceIterator object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 0, then we setup a buffer on each device to
        prefetch into.
      source_device: The host device to place the `dataset` on.  In order to
        prevent deadlocks, if the prefetch_buffer_size is greater than the
        max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
      components: Tensor components to construct the MultiDeviceIterator from.
      element_spec: A (nested) structure of `tf.TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode or outside of function building
        mode.
    """
    if not context.executing_eagerly() and not ops.inside_function():
      raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
                         "tf.function or when eager execution is enabled.")
    if devices is None:
      raise ValueError("`devices` must be provided")
    error_message = ("Either `dataset` or both `components` and "
                     "`element_spec` need to be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._devices = devices
      self._source_device = source_device
      self._multi_device_iterator_resource = components[0]
      self._deleter = components[1]
      self._device_iterators = components[2:]
      iterator_handles = []
      for it in self._device_iterators:
        iterator_handles.append(it._iterator_resource)  # pylint: disable=protected-access
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      options = options_lib.Options()
      options.experimental_distribute.num_devices = len(devices)
      dataset = dataset.with_options(options)
      dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
      self._element_spec = dataset.element_spec
      experimental_slack = dataset.options().experimental_slack
      self._devices = devices
      self._source_device = source_device
      source_device_tensor = ops.convert_to_tensor(self._source_device)

      if prefetch_buffer_size > max_buffer_size:
        max_buffer_size = prefetch_buffer_size

      # Create the MultiDeviceIterator.
      with ops.device(self._source_device):
        self._multi_device_iterator_resource, self._deleter = (
            gen_dataset_ops.anonymous_multi_device_iterator(
                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

        # The incarnation ID is used to ensure consistency between the
        # per-device iterators and the multi-device iterator.
        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
            dataset._variant_tensor,  # pylint: disable=protected-access
            self._multi_device_iterator_resource,
            max_buffer_size=max_buffer_size)

      prototype_device_datasets = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                   incarnation_id, source_device_tensor,
                                   dataset.element_spec)
          prototype_device_datasets.append(ds)

      # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
      # initialize the device side of the pipeline. This would allow the
      # MultiDeviceIterator to choose, for example, to move some transformations
      # into the device side from its input. It might be useful in rewriting.
      # Create the per device iterators.
      self._device_iterators = []
      iterator_handles = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _create_device_dataset(prototype_device_datasets[i],
                                      incarnation_id, prefetch_buffer_size,
                                      experimental_slack)
          iterator = iter(ds)
          self._device_iterators.append(iterator)
          iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

      self._resource_deleter = MultiDeviceIteratorResourceDeleter(
          multi_device_iterator=self._multi_device_iterator_resource,
          iterators=iterator_handles,
          device=self._source_device,
          deleter=self._deleter)

  def get_next(self, device=None):
    """Returns the next element given a `device`, else returns all in a list."""
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return MultiDeviceIteratorSpec(self._devices, self._source_device,
                                   self._element_spec)

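# A minimal usage sketch for `OwnedMultiDeviceIterator` (illustrative only;
# assumes eager execution and two visible GPUs):
#
#   dataset = tf.data.Dataset.range(10)
#   mdi = OwnedMultiDeviceIterator(dataset, devices=["/gpu:0", "/gpu:1"])
#   for elem_gpu0, elem_gpu1 in mdi:
#     ...
#
# Unlike `MultiDeviceIterator`, the underlying resource is tied to this Python
# object's lifetime via `MultiDeviceIteratorResourceDeleter`, so no explicit
# initializer or manual cleanup is needed.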