# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops


class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset.

  Wraps one shard of a multi-device iterator, fetching its elements from
  `source_device` via remote calls.
  """

  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_spec):
    self._element_spec = element_spec

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=structure.get_flat_tensor_types(self._element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self._element_spec)))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=structure.get_flat_tensor_types(self._element_spec),
          output_shapes=structure.get_flat_tensor_shapes(self._element_spec))

    next_func_concrete = _next_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=structure.get_flat_tensor_types(self._element_spec),
          f=next_func_concrete)

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg is incarnation_id:
        self._incarnation_id_index = i

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    return self._element_spec


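# Note on the pattern above: `generator_dataset` stitches together three
# concrete functions: an init function that produces a string handle for the
# multi-device iterator, a next function that pulls this shard's next element
# via `multi_device_iterator_get_next_from_shard`, and a no-op finalize
# function. Each is wrapped in `functional_ops.remote_call` so that it runs
# on `source_device` no matter which device the generator dataset itself is
# placed on.

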
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.

  Re-uses the functions from the provided per_device_dataset and just switches
  out the function argument corresponding to the incarnation_id.
  """

  def __init__(self, per_device_dataset, incarnation_id):
    # pylint: disable=protected-access
    self._element_spec = per_device_dataset.element_spec
    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    self._next_captured_args = per_device_dataset._next_captured_args
    # The captured arguments to the next_func are string_handle, incarnation_id.
    # We update the incarnation id to the new one.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    return self._element_spec


def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size,
                           experimental_slack):
  """Uses `prototype_ds` to build a dataset for the device."""
  ds = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id)
  if prefetch_buffer_size > 0:
    if experimental_slack:
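      # Construct the prefetch via PrefetchDataset directly so that a
      # slack_period can be set; `Dataset.prefetch()` does not expose one.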
      ds = dataset_ops.PrefetchDataset(ds, prefetch_buffer_size, slack_period=1)
    else:
      ds = ds.prefetch(prefetch_buffer_size)
  # TODO(jsimsa): Enable auto-tuning and optimizations when supported for
  # non-CPU devices.
  options = dataset_ops.Options()
  options.experimental_optimization.apply_default_optimizations = False
  options.experimental_optimization.autotune = False
  ds = ds.with_options(options)
  return ds


class MultiDeviceIterator(object):
  """An iterator over multiple devices.
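
  Example usage (a minimal sketch, assuming `import tensorflow as tf`, eager
  execution, and that the listed devices exist):

    dataset = tf.data.Dataset.range(10)
    mdi = MultiDeviceIterator(dataset, devices=["/gpu:0", "/gpu:1"])
    for _ in range(5):
      element_gpu0, element_gpu1 = mdi.get_next()
  """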

  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host-side per-device buffer to
        keep. To prevent deadlocks, if `prefetch_buffer_size` is greater than
        `max_buffer_size`, then `max_buffer_size` is raised to
        `prefetch_buffer_size`.
      prefetch_buffer_size: If > 0, a buffer of this size is set up on each
        device to prefetch into.
      source_device: The host device to place the `dataset` on.
    """
    options = dataset_ops.Options()
    options.experimental_distribute.num_devices = len(devices)
    dataset = dataset.with_options(options)
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **self._dataset._flat_structure))  # pylint: disable=protected-access
      if context.executing_eagerly():
        # Delete the resource when this object is deleted.
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                 self._incarnation_id,
                                 self._source_device_tensor,
                                 self._dataset.element_spec)
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of having the MultiDeviceIterator
    # initialize the device side of the pipeline. That would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # from its input into the device side. It might be useful in rewriting.
    # Create the per-device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)

  def _create_device_dataset(self, i):
    """Uses _prototype_device_datasets[i] to build a dataset for the device."""
    ds = self._prototype_device_datasets[i]
    ds = _ReincarnatedPerDeviceGenerator(ds, self._incarnation_id)
    if self._prefetch_buffer_size > 0:
      if self._experimental_slack:
        ds = dataset_ops.PrefetchDataset(
            ds, self._prefetch_buffer_size, slack_period=1)
      else:
        ds = ds.prefetch(self._prefetch_buffer_size)
    # TODO(jsimsa): Enable auto-tuning and optimizations when supported for
    # non-CPU devices.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.autotune = False
    ds = ds.with_options(options)
    return ds

  def get_next(self, device=None):
    """Returns the next element for the given `device`.

    If `device` is None, returns a list with the next element for each device.
    """
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def get_next_as_optional(self):
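    """Like `get_next`, but returns a list of per-device `Optional`s.

    An empty optional, rather than an `OutOfRangeError`, indicates that a
    device's iterator is exhausted.
    """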
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def initializer(self):
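    """A `tf.Operation` that should be run to initialize this iterator.

    In eager mode the per-device iterators are one-shot and need no
    initialization, so this returns a no-op.
    """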
    if context.executing_eagerly():
      return control_flow_ops.no_op()
    return self._initializer

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode."""
    if not ops.executing_eagerly_outside_functions():
      raise ValueError("Eager reset is only supported in eager mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def element_spec(self):
    return self._dataset.element_spec


class MultiDeviceIteratorResourceDeleter(object):
  """An object which cleans up a Multi Device Iterator resource.

  An alternative to defining a __del__ method on an object. Even if the parent
  object is part of a reference cycle, the cycle will be collectible.
  """

  __slots__ = [
      "_deleter", "_multi_device_iterator", "_iterators", "_device",
      "_eager_mode"
  ]

  def __init__(self, multi_device_iterator, iterators, device, deleter):
    self._deleter = deleter
    self._multi_device_iterator = multi_device_iterator
    self._iterators = iterators
    self._device = device
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    with ops.device(self._device):
      # Make sure the resource is deleted in the same mode as it was created in.
      # We pass in the iterator handles as inputs to the op to make sure that
      # this op runs after all the iterators are deleted.
      if self._eager_mode:
        with context.eager_mode():
          gen_dataset_ops.delete_multi_device_iterator(
              multi_device_iterator=self._multi_device_iterator,
              iterators=self._iterators,
              deleter=self._deleter)
      else:
        with context.graph_mode():
          gen_dataset_ops.delete_multi_device_iterator(
              multi_device_iterator=self._multi_device_iterator,
              iterators=self._iterators,
              deleter=self._deleter)


class MultiDeviceIteratorSpec(type_spec.TypeSpec):
  """Type specification for `OwnedMultiDeviceIterator`."""

  __slots__ = ["_devices", "_source_device", "_element_spec"]

  def __init__(self, devices, source_device, element_spec):
    self._devices = devices
    self._source_device = source_device
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedMultiDeviceIterator

  def _serialize(self):
    return (tuple(self._devices), self._source_device, self._element_spec)

  @property
  def _component_specs(self):
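    # Components are, in order: the multi-device iterator resource handle,
    # its deleter variant, and one owned iterator per device. This must stay
    # in sync with `_to_components` below.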
    specs = [
        tensor_spec.TensorSpec([], dtypes.resource),
        tensor_spec.TensorSpec([], dtypes.variant)
    ]
    for _ in range(len(self._devices)):
      specs.append(iterator_ops.IteratorSpec(self._element_spec))
    return specs

  def _to_components(self, value):
    # pylint: disable=protected-access
    c = [value._multi_device_iterator_resource, value._deleter]
    c.extend(value._device_iterators)
    return c

  def _from_components(self, components):
    return OwnedMultiDeviceIterator(
        dataset=None,
        devices=self._devices,
        source_device=self._source_device,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return MultiDeviceIteratorSpec(
        value._devices,
        value._source_device,
        value.element_spec)


class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
  """An iterator over multiple devices.

  The multi-device iterator resource created through `OwnedMultiDeviceIterator`
  is owned by the Python object, and the lifetime of the underlying resource is
  tied to the lifetime of the `OwnedMultiDeviceIterator` object. This makes
  `OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of
  `tf.function`s.
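
  Example usage (a minimal sketch, assuming `import tensorflow as tf`, eager
  execution, and that the listed devices exist):

    dataset = tf.data.Dataset.range(10)
    mdi = OwnedMultiDeviceIterator(dataset, devices=["/gpu:0", "/gpu:1"])
    for element_gpu0, element_gpu1 in mdi:
      pass  # Process one element per device.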
  """

  def __init__(self,
               dataset=None,
               devices=None,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0",
               components=None,
               element_spec=None):
    """Constructs an owned MultiDeviceIterator object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host-side per-device buffer to
        keep. To prevent deadlocks, if `prefetch_buffer_size` is greater than
        `max_buffer_size`, then `max_buffer_size` is raised to
        `prefetch_buffer_size`.
      prefetch_buffer_size: If > 0, a buffer of this size is set up on each
        device to prefetch into.
      source_device: The host device to place the `dataset` on.
      components: Tensor components to construct the MultiDeviceIterator from.
      element_spec: A nested structure of `TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode outside of a `tf.function`.
    """
    if not context.executing_eagerly() and not ops.inside_function():
      raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
                         "tf.function or when eager execution is enabled.")
    if devices is None:
      raise ValueError("`devices` must be provided")
    error_message = ("Either `dataset` or both `components` and "
                     "`element_spec` need to be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._devices = devices
      self._source_device = source_device
      self._multi_device_iterator_resource = components[0]
      self._deleter = components[1]
      self._device_iterators = components[2:]
      iterator_handles = []
      for it in self._device_iterators:
        iterator_handles.append(it._iterator_resource)  # pylint: disable=protected-access
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      options = dataset_ops.Options()
      options.experimental_distribute.num_devices = len(devices)
      dataset = dataset.with_options(options)
      dataset = dataset._apply_options()  # pylint: disable=protected-access
      self._element_spec = dataset.element_spec
      experimental_slack = dataset.options().experimental_slack
      self._devices = devices
      self._source_device = source_device
      source_device_tensor = ops.convert_to_tensor(self._source_device)

      if prefetch_buffer_size > max_buffer_size:
        max_buffer_size = prefetch_buffer_size

      # Create the MultiDeviceIterator.
      with ops.device(self._source_device):
        self._multi_device_iterator_resource, self._deleter = (
            gen_dataset_ops.anonymous_multi_device_iterator(
                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

        # The incarnation ID is used to ensure consistency between the
        # per-device iterators and the multi-device iterator.
        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
            dataset._variant_tensor,  # pylint: disable=protected-access
            self._multi_device_iterator_resource,
            max_buffer_size=max_buffer_size)

      prototype_device_datasets = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
                                   incarnation_id, source_device_tensor,
                                   dataset.element_spec)
          prototype_device_datasets.append(ds)

      # TODO(rohanj): Explore the possibility of having the MultiDeviceIterator
      # initialize the device side of the pipeline. That would allow the
      # MultiDeviceIterator to choose, for example, to move some transformations
      # from its input into the device side. It might be useful in rewriting.
      # Create the per-device iterators.
      self._device_iterators = []
      iterator_handles = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _create_device_dataset(prototype_device_datasets[i],
                                      incarnation_id, prefetch_buffer_size,
                                      experimental_slack)
          iterator = iter(ds)
          self._device_iterators.append(iterator)
          iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

      self._resource_deleter = MultiDeviceIteratorResourceDeleter(
          multi_device_iterator=self._multi_device_iterator_resource,
          iterators=iterator_handles,
          device=self._source_device,
          deleter=self._deleter)

  def get_next(self, device=None):
    """Returns the next element for the given `device`.

    If `device` is None, returns a list with the next element for each device.
    """
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def __iter__(self):
    return self

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def get_next_as_optional(self):
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return MultiDeviceIteratorSpec(self._devices, self._source_device,
                                   self._element_spec)