# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Non-deterministic dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python import tf2
21from tensorflow.python.data.experimental.ops import random_ops
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.ops import readers
24from tensorflow.python.data.util import nest
25from tensorflow.python.data.util import structure
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import gen_experimental_dataset_ops
31from tensorflow.python.ops import gen_stateless_random_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.util import deprecation
34from tensorflow.python.util.tf_export import tf_export
35
36
37@deprecation.deprecated(
38    None,
39    "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
40    "num_parallel_calls=tf.data.AUTOTUNE)` instead. If sloppy "
41    "execution is desired, use `tf.data.Options.deterministic`.")
42@tf_export("data.experimental.parallel_interleave")
43def parallel_interleave(map_func,
44                        cycle_length,
45                        block_length=1,
46                        sloppy=False,
47                        buffer_output_elements=None,
48                        prefetch_input_elements=None):
49  """A parallel version of the `Dataset.interleave()` transformation.
50
51  `parallel_interleave()` maps `map_func` across its input to produce nested
52  datasets, and outputs their elements interleaved. Unlike
53  `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
54  datasets in parallel, which increases the throughput, especially in the
55  presence of stragglers. Furthermore, the `sloppy` argument can be used to
56  improve performance, by relaxing the requirement that the outputs are produced
57  in a deterministic order, and allowing the implementation to skip over nested
58  datasets whose elements are not readily available when requested.
59
60  Example usage:
61
62  ```python
63  # Preprocess 4 files concurrently.
64  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
65  dataset = filenames.apply(
66      tf.data.experimental.parallel_interleave(
67          lambda filename: tf.data.TFRecordDataset(filename),
68          cycle_length=4))
69  ```
70
71  WARNING: If `sloppy` is `True`, the order of produced elements is not
72  deterministic.
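
  As a rough migration sketch (not part of the original example; assumes
  TF 2.x, where `tf.data.AUTOTUNE` and `tf.data.Options.deterministic` are
  available), the deprecation notice above maps onto:

  ```python
  dataset = filenames.interleave(
      lambda filename: tf.data.TFRecordDataset(filename),
      cycle_length=4,
      num_parallel_calls=tf.data.AUTOTUNE)
  options = tf.data.Options()
  options.deterministic = False  # analogous to `sloppy=True`
  dataset = dataset.with_options(options)
  ```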

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: A boolean controlling whether determinism should be traded for
      performance by allowing elements to be produced out of order.  If `sloppy`
      is `None`, the `tf.data.Options.deterministic` dataset option (`True` by
      default) is used to decide whether to enforce a deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation for
      each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return readers.ParallelInterleaveDataset(dataset, map_func, cycle_length,
                                             block_length, sloppy,
                                             buffer_output_elements,
                                             prefetch_input_elements)

  return _apply_fn


class _DirectedInterleaveDataset(dataset_ops.DatasetV2):
  """A substitute for `Dataset.interleave()` on a fixed list of datasets."""

  def __init__(self, selector_input, data_inputs, stop_on_empty_dataset=False):
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)
    self._stop_on_empty_dataset = stop_on_empty_dataset

    first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0])
    first_output_classes = dataset_ops.get_legacy_output_classes(data_inputs[0])

    for i, data_input in enumerate(data_inputs[1:]):
      if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
          or dataset_ops.get_legacy_output_classes(data_input) !=
          first_output_classes):
        raise TypeError("All datasets must have the same type and class.\n"
                        "dataset 0 vs dataset %s types: %s ; %s\n"
                        "classes: %s ; %s" %
                        (i + 1, first_output_types,
                         dataset_ops.get_legacy_output_types(data_input),
                         first_output_classes,
                         dataset_ops.get_legacy_output_classes(data_input)))

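    # Merge the inputs' element specs into the most specific spec that is
    # compatible with each of them; this becomes the merged dataset's spec.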
    spec = self._data_inputs[0].element_spec
    for data_input in self._data_inputs[1:]:
      spec = nest.pack_sequence_as(spec, [
          x.most_specific_compatible_type(y) for (x, y) in zip(
              nest.flatten(spec),
              nest.flatten(data_input.element_spec))
      ])
    self._element_spec = spec

    # pylint: disable=protected-access
    variant_tensor = (
        gen_experimental_dataset_ops.directed_interleave_dataset(
            self._selector_input._variant_tensor,
            [data_input._variant_tensor for data_input in self._data_inputs],
            stop_on_empty_dataset=self._stop_on_empty_dataset,
            **self._flat_structure))

    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)

  def _inputs(self):
    return [self._selector_input] + self._data_inputs

  @property
  def element_spec(self):
    return self._element_spec


@tf_export("data.experimental.sample_from_datasets", v1=[])
def sample_from_datasets_v2(datasets,
                            weights=None,
                            seed=None,
                            stop_on_empty_dataset=False):
  """Samples elements at random from the datasets in `datasets`.

  Creates a dataset by interleaving elements of `datasets` with `weights[i]`
  probability of picking an element from dataset `i`. Sampling is done without
  replacement. For example, suppose we have 2 datasets:

  ```python
  dataset1 = tf.data.Dataset.range(0, 3)
  dataset2 = tf.data.Dataset.range(100, 103)
  ```

  Suppose also that we sample from these 2 datasets with the following weights:

  ```python
  sample_dataset = tf.data.experimental.sample_from_datasets(
      [dataset1, dataset2], weights=[0.5, 0.5])
  ```

  One possible outcome of the elements in `sample_dataset` is:

  ```
  print(list(sample_dataset.as_numpy_iterator()))
  # [100, 0, 1, 101, 2, 102]
  ```
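
  `weights` may also itself be a dataset, in which case a fresh weight vector
  is read before each selection. A minimal sketch (not from the original
  docstring; the 90/10 split is illustrative):

  ```python
  # Draw from `dataset1` about 90% of the time, `dataset2` about 10%.
  weights_ds = tf.data.Dataset.from_tensors([0.9, 0.1]).repeat()
  sample_dataset = tf.data.experimental.sample_from_datasets(
      [dataset1, dataset2], weights=weights_ds)
  ```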

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    weights: (Optional.) A list or Tensor of `len(datasets)` floating-point
      values where `weights[i]` represents the probability to sample from
      `datasets[i]`, or a `tf.data.Dataset` object where each element is such a
      list. Defaults to a uniform distribution across `datasets`.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.
    stop_on_empty_dataset: If `True`, sampling stops if it encounters an empty
      dataset. If `False`, it skips empty datasets. It is recommended to set it
      to `True`. Otherwise, the distribution of samples starts off as the user
      intends, but may change as input datasets become empty. This can be
      difficult to detect since the dataset starts off looking correct. Defaults
      to `False` for backward compatibility.

  Returns:
    A dataset that interleaves elements from `datasets` at random, according to
    `weights` if provided, otherwise with uniform probability.

  Raises:
    TypeError: If the `datasets` or `weights` arguments have the wrong type.
    ValueError:
      - If `datasets` is empty, or
      - If `weights` is specified and does not match the length of `datasets`.
  """
  def _shapes_are_compatible(datasets, weights):
    if isinstance(weights, ops.Tensor):
      return weights.shape.is_compatible_with([len(datasets)])
    return len(datasets) == len(weights)

  def _skip_datasets_with_zero_weight(datasets, weights):
    datasets_and_weights = [(dataset, weight)
                            for (dataset, weight) in zip(datasets, weights)
                            if weight > 0]
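    # If every weight is zero, fall back to a single empty dataset with
    # weight 1, so the caller still receives a (now empty) dataset.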
    return (zip(*datasets_and_weights) if datasets_and_weights else
            ([datasets[0].take(0)], [1.]))

  if not datasets:
    raise ValueError("`datasets` must be a non-empty list of datasets.")

  if not isinstance(weights, dataset_ops.DatasetV2):
    if weights is None:
      # Select inputs with uniform probability.
      logits = [[1.0] * len(datasets)]

    else:
      if not _shapes_are_compatible(datasets, weights):
        raise ValueError("`weights` must have the same length as `datasets`.")

      # Use the given `weights` as the probability of choosing the respective
      # input.
      if not isinstance(weights, ops.Tensor):
        datasets, weights = _skip_datasets_with_zero_weight(datasets, weights)
      weights = ops.convert_to_tensor(weights, name="weights")
      if weights.dtype not in (dtypes.float32, dtypes.float64):
        raise TypeError("`weights` must be convertible to a tensor of "
                        "`tf.float32` or `tf.float64` elements.")

      # The `stateless_multinomial()` op expects log-probabilities, as opposed
      # to weights.
      logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)

    # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
    # is a `Dataset`, it is possible that evaluating it has a side effect the
    # user depends on.
    if len(datasets) == 1:
      return datasets[0]

    def select_dataset_constant_logits(seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

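    # Each selection consumes a batch of two scalars from the random dataset;
    # this pair is the shape-[2] seed expected by the stateless multinomial op.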
    selector_input = dataset_ops.MapDataset(
        random_ops.RandomDataset(seed).batch(2),
        select_dataset_constant_logits,
        use_inter_op_parallelism=False)

  else:
    # Use each element of the given `weights` dataset as the probability of
    # choosing the respective input.
    #
    # The `stateless_multinomial()` op expects log-probabilities, as opposed to
    # weights.
    logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))

    def select_dataset_varying_logits(logits, seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    logits_and_seeds = dataset_ops.Dataset.zip(
        (logits_ds, random_ops.RandomDataset(seed).batch(2)))
    selector_input = dataset_ops.MapDataset(
        logits_and_seeds,
        select_dataset_varying_logits,
        use_inter_op_parallelism=False)

  return _DirectedInterleaveDataset(selector_input, datasets,
                                    stop_on_empty_dataset)


@tf_export(v1=["data.experimental.sample_from_datasets"])
def sample_from_datasets_v1(datasets,
                            weights=None,
                            seed=None,
                            stop_on_empty_dataset=False):
  return dataset_ops.DatasetV1Adapter(
      sample_from_datasets_v2(datasets, weights, seed, stop_on_empty_dataset))


sample_from_datasets_v1.__doc__ = sample_from_datasets_v2.__doc__


@tf_export("data.experimental.choose_from_datasets", v1=[])
def choose_from_datasets_v2(datasets,
                            choice_dataset,
                            stop_on_empty_dataset=False):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```
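
  With `stop_on_empty_dataset=True`, iteration ends as soon as the selected
  dataset has no elements left. A minimal sketch (not from the original
  docstring; note that `choice_dataset` must yield scalar `tf.int64` values):

  ```python
  datasets = [tf.data.Dataset.from_tensors("a"),           # one element
              tf.data.Dataset.from_tensors("b").repeat()]  # unbounded
  choice_dataset = tf.data.Dataset.from_tensor_slices(
      tf.constant([0, 1, 0, 1], dtype=tf.int64))
  result = tf.data.experimental.choose_from_datasets(
      datasets, choice_dataset, stop_on_empty_dataset=True)
  # Stops at the second `0` because `datasets[0]` is exhausted:
  print(list(result.as_numpy_iterator()))  # [b'a', b'b']
  ```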

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between `0`
      and `len(datasets) - 1`.
    stop_on_empty_dataset: If `True`, selection stops if it encounters an empty
      dataset. If `False`, it skips empty datasets. It is recommended to set it
      to `True`. Otherwise, the selected elements start off as the user intends,
      but may change as input datasets become empty. This can be difficult to
      detect since the dataset starts off looking correct. Defaults to `False`
      for backward compatibility.

  Returns:
    A dataset that interleaves elements from `datasets` according to the values
    of `choice_dataset`.

  Raises:
    TypeError: If `datasets` or `choice_dataset` has the wrong type.
    ValueError: If `datasets` is empty.
  """
  if not datasets:
    raise ValueError("`datasets` must be a non-empty list of datasets.")
  if choice_dataset is None or not structure.are_compatible(
      choice_dataset.element_spec, tensor_spec.TensorSpec([], dtypes.int64)):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets,
                                    stop_on_empty_dataset)


@tf_export(v1=["data.experimental.choose_from_datasets"])
def choose_from_datasets_v1(datasets,
                            choice_dataset,
                            stop_on_empty_dataset=False):
  return dataset_ops.DatasetV1Adapter(
      choose_from_datasets_v2(datasets, choice_dataset, stop_on_empty_dataset))


choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__

if tf2.enabled():
  choose_from_datasets = choose_from_datasets_v2
  sample_from_datasets = sample_from_datasets_v2
else:
  choose_from_datasets = choose_from_datasets_v1
  sample_from_datasets = sample_from_datasets_v1