# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Non-deterministic dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation


class ParallelInterleaveDataset(dataset_ops.Dataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.contrib.data.parallel_interleave()` for details."""
    super(ParallelInterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

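      # Record the output structure of the returned dataset so that the
      # output_classes/output_types/output_shapes properties below report the
      # structure of the elements produced by `map_func`.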
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
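    # When left unspecified, `buffer_output_elements` and
    # `prefetch_input_elements` default to two blocks of output per iterator
    # and two cycles' worth of input elements, respectively (the
    # `argument_default` values below).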
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)

  def _as_variant_tensor(self):
    return gen_dataset_ops.parallel_interleave_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._sloppy,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` maps `map_func` across its input to produce nested
  datasets, and outputs their elements interleaved. Unlike
  @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
  datasets in parallel, which increases throughput, especially in the
  presence of stragglers. Furthermore, the `sloppy` argument can be used to
  improve performance by relaxing the requirement that outputs be produced in
  a deterministic order, and by allowing the implementation to skip over
  nested datasets whose elements are not readily available when requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.contrib.data.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.
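
  For example, `sloppy` mode can be combined with explicit buffering to trade
  output determinism for throughput (a sketch; the parameter values below are
  illustrative assumptions, not tuned recommendations):

  ```python
  dataset = filenames.apply(
      tf.contrib.data.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4,
          sloppy=True,                  # Allow non-deterministic ordering.
          buffer_output_elements=8,     # Buffer 8 elements per open iterator.
          prefetch_input_elements=2))   # Open 2 inputs ahead of need.
  ```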

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: If `False`, elements are produced in deterministic order. Otherwise,
      the implementation is allowed, for the sake of expediency, to produce
      elements in a non-deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation for
      each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform into
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    return ParallelInterleaveDataset(
        dataset, map_func, cycle_length, block_length, sloppy,
        buffer_output_elements, prefetch_input_elements)

  return _apply_fn


@deprecation.deprecated(
    None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
def sloppy_interleave(map_func, cycle_length, block_length=1):
  """A non-deterministic version of the `Dataset.interleave()` transformation.

  `sloppy_interleave()` maps `map_func` across `dataset` and
  non-deterministically interleaves the results.

  The resulting dataset is almost identical to `interleave`. The key
  difference is that if retrieving a value from a given output iterator would
  cause `get_next` to block, that iterator is skipped, and its element is
  consumed when it next becomes available. If consuming from all iterators
  would cause the `get_next` call to block, the `get_next` call blocks until
  the first value is available.

  If the underlying datasets produce elements as fast as they are consumed, the
  `sloppy_interleave` transformation behaves identically to `interleave`.
  However, if an underlying dataset would block the consumer,
  `sloppy_interleave` can violate the round-robin order (which `interleave`
  strictly obeys), producing an element from a different underlying
  dataset instead.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.contrib.data.sloppy_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: The order of elements in the resulting dataset is not
  deterministic. Use `Dataset.interleave()` if you want the elements to have a
  deterministic order.
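
  Since this transformation is deprecated, an equivalent pipeline can be built
  with `parallel_interleave` directly (a sketch, reusing the `filenames`
  dataset from the example above):

  ```python
  dataset = filenames.apply(
      tf.contrib.data.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4,
          sloppy=True))
  ```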

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to a
      `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`. Note:
      `sloppy_interleave` will skip the remainder of elements in the
      `block_length` in order to avoid blocking.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    return ParallelInterleaveDataset(
        dataset,
        map_func,
        cycle_length,
        block_length,
        sloppy=True,
        buffer_output_elements=None,
        prefetch_input_elements=None)

  return _apply_fn