# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Non-deterministic dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
    "num_parallel_calls=tf.data.AUTOTUNE)` instead. If sloppy "
    "execution is desired, use `tf.data.Options.experimental_deterministic`.")
@tf_export("data.experimental.parallel_interleave")
def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` maps `map_func` across its input to produce nested
  datasets, and outputs their elements interleaved. Unlike
  `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
  datasets in parallel, which increases the throughput, especially in the
  presence of stragglers. Furthermore, the `sloppy` argument can be used to
  improve performance, by relaxing the requirement that the outputs are
  produced in a deterministic order, and allowing the implementation to skip
  over nested datasets whose elements are not readily available when
  requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.data.experimental.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.
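
  The deprecation message above names the replacement API; a rough sketch of
  the equivalent call (assuming `tf.data.AUTOTUNE` is available in your
  TensorFlow version) is:

  ```python
  dataset = filenames.interleave(
      lambda filename: tf.data.TFRecordDataset(filename),
      cycle_length=4,
      num_parallel_calls=tf.data.AUTOTUNE)
  ```

  If sloppy execution is desired, `tf.data.Options.experimental_deterministic`
  can be set to `False` and applied with `Dataset.with_options`.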

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in
      parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: A boolean controlling whether determinism should be traded for
      performance by allowing elements to be produced out of order. If
      `sloppy` is `None`, the `tf.data.Options.experimental_deterministic`
      dataset option (`True` by default) is used to decide whether to enforce
      a deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation
      for each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return readers.ParallelInterleaveDataset(dataset, map_func, cycle_length,
                                             block_length, sloppy,
                                             buffer_output_elements,
                                             prefetch_input_elements)

  return _apply_fn


class _DirectedInterleaveDataset(dataset_ops.DatasetV2):
  """A substitute for `Dataset.interleave()` on a fixed list of datasets."""

  def __init__(self, selector_input, data_inputs):
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)

    first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0])
    first_output_classes = dataset_ops.get_legacy_output_classes(
        data_inputs[0])

    for i, data_input in enumerate(data_inputs[1:]):
      if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
          or dataset_ops.get_legacy_output_classes(data_input)
          != first_output_classes):
        raise TypeError("All datasets must have the same type and class.\n"
                        "dataset 0 vs dataset %s types: %s ; %s\n"
                        "classes: %s ; %s" %
                        (i + 1, first_output_types,
                         dataset_ops.get_legacy_output_types(data_input),
                         first_output_classes,
                         dataset_ops.get_legacy_output_classes(data_input)))

    output_shapes = dataset_ops.get_legacy_output_shapes(self._data_inputs[0])
    for data_input in self._data_inputs[1:]:
      output_shapes = nest.pack_sequence_as(output_shapes, [
          ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
              nest.flatten(output_shapes),
              nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
      ])

    self._element_spec = structure.convert_legacy_structure(
        first_output_types, output_shapes, first_output_classes)
    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.directed_interleave_dataset(
        self._selector_input._variant_tensor,
        [data_input._variant_tensor for data_input in self._data_inputs],
        **self._flat_structure)
    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)

  def _inputs(self):
    return [self._selector_input] + self._data_inputs

  @property
  def element_spec(self):
    return self._element_spec


@tf_export("data.experimental.sample_from_datasets", v1=[])
def sample_from_datasets_v2(datasets, weights=None, seed=None):
  """Samples elements at random from the datasets in `datasets`.
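
  Example usage (a minimal sketch; the constant datasets and weights are
  illustrative only):

  ```python
  ds1 = tf.data.Dataset.from_tensors("foo").repeat()
  ds2 = tf.data.Dataset.from_tensors("bar").repeat()

  # Produces "foo" with probability 0.8 and "bar" with probability 0.2.
  dataset = tf.data.experimental.sample_from_datasets(
      [ds1, ds2], weights=[0.8, 0.2], seed=42)
  ```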

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    weights: (Optional.) A list of `len(datasets)` floating-point values where
      `weights[i]` represents the probability with which an element should be
      sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
      element is such a list. Defaults to a uniform distribution across
      `datasets`.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.

  Returns:
    A dataset that interleaves elements from `datasets` at random, according
    to `weights` if provided, otherwise with uniform probability.

  Raises:
    TypeError: If the `datasets` or `weights` arguments have the wrong type.
    ValueError: If `weights` is specified and its length does not match
      `len(datasets)`.
  """
  num_datasets = len(datasets)
  if not isinstance(weights, dataset_ops.DatasetV2):
    if weights is None:
      # Select inputs with uniform probability.
      logits = [[1.0] * num_datasets]

    else:
      # Use the given `weights` as the probability of choosing the respective
      # input.
      weights = ops.convert_to_tensor(weights, name="weights")
      if weights.dtype not in (dtypes.float32, dtypes.float64):
        raise TypeError("`weights` must be convertible to a tensor of "
                        "`tf.float32` or `tf.float64` elements.")
      if not weights.shape.is_compatible_with([num_datasets]):
        raise ValueError(
            "`weights` must be a vector of length `len(datasets)`.")

      # The `stateless_multinomial()` op expects log-probabilities, as opposed
      # to weights.
      logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)

    # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When
    # it is a `Dataset`, it is possible that evaluating it has a side effect
    # the user depends on.
    if len(datasets) == 1:
      return datasets[0]

    def select_dataset_constant_logits(seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    selector_input = dataset_ops.MapDataset(
        random_ops.RandomDataset(seed).batch(2),
        select_dataset_constant_logits,
        use_inter_op_parallelism=False)

  else:
    # Use each element of the given `weights` dataset as the probability of
    # choosing the respective input.

    # The `stateless_multinomial()` op expects log-probabilities, as opposed
    # to weights.
    logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))

    def select_dataset_varying_logits(logits, seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    logits_and_seeds = dataset_ops.Dataset.zip(
        (logits_ds, random_ops.RandomDataset(seed).batch(2)))
    selector_input = dataset_ops.MapDataset(
        logits_and_seeds,
        select_dataset_varying_logits,
        use_inter_op_parallelism=False)

  return _DirectedInterleaveDataset(selector_input, datasets)


@tf_export(v1=["data.experimental.sample_from_datasets"])
def sample_from_datasets_v1(datasets, weights=None, seed=None):
  return dataset_ops.DatasetV1Adapter(
      sample_from_datasets_v2(datasets, weights, seed))
sample_from_datasets_v1.__doc__ = sample_from_datasets_v2.__doc__


@tf_export("data.experimental.choose_from_datasets", v1=[])
def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the
    values of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
  if not structure.are_compatible(choice_dataset.element_spec,
                                  tensor_spec.TensorSpec([], dtypes.int64)):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)


@tf_export(v1=["data.experimental.choose_from_datasets"])
def choose_from_datasets_v1(datasets, choice_dataset):
  return dataset_ops.DatasetV1Adapter(
      choose_from_datasets_v2(datasets, choice_dataset))
choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__


if tf2.enabled():
  choose_from_datasets = choose_from_datasets_v2
  sample_from_datasets = sample_from_datasets_v2
else:
  choose_from_datasets = choose_from_datasets_v1
  sample_from_datasets = sample_from_datasets_v1