1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Non-deterministic dataset transformations.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python import tf2 21from tensorflow.python.data.experimental.ops import random_ops 22from tensorflow.python.data.ops import dataset_ops 23from tensorflow.python.data.ops import readers 24from tensorflow.python.data.util import nest 25from tensorflow.python.data.util import structure 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_spec 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import gen_experimental_dataset_ops 31from tensorflow.python.ops import gen_stateless_random_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.util import deprecation 34from tensorflow.python.util.tf_export import tf_export 35 36 37@deprecation.deprecated( 38 None, 39 "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, " 40 "num_parallel_calls=tf.data.AUTOTUNE)` instead. If sloppy " 41 "execution is desired, use `tf.data.Options.deterministic`.") 42@tf_export("data.experimental.parallel_interleave") 43def parallel_interleave(map_func, 44 cycle_length, 45 block_length=1, 46 sloppy=False, 47 buffer_output_elements=None, 48 prefetch_input_elements=None): 49 """A parallel version of the `Dataset.interleave()` transformation. 50 51 `parallel_interleave()` maps `map_func` across its input to produce nested 52 datasets, and outputs their elements interleaved. Unlike 53 `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested 54 datasets in parallel, which increases the throughput, especially in the 55 presence of stragglers. Furthermore, the `sloppy` argument can be used to 56 improve performance, by relaxing the requirement that the outputs are produced 57 in a deterministic order, and allowing the implementation to skip over nested 58 datasets whose elements are not readily available when requested. 59 60 Example usage: 61 62 ```python 63 # Preprocess 4 files concurrently. 64 filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") 65 dataset = filenames.apply( 66 tf.data.experimental.parallel_interleave( 67 lambda filename: tf.data.TFRecordDataset(filename), 68 cycle_length=4)) 69 ``` 70 71 WARNING: If `sloppy` is `True`, the order of produced elements is not 72 deterministic. 73 74 Args: 75 map_func: A function mapping a nested structure of tensors to a `Dataset`. 76 cycle_length: The number of input `Dataset`s to interleave from in parallel. 77 block_length: The number of consecutive elements to pull from an input 78 `Dataset` before advancing to the next input `Dataset`. 79 sloppy: A boolean controlling whether determinism should be traded for 80 performance by allowing elements to be produced out of order. If `sloppy` 81 is `None`, the `tf.data.Options.deterministic` dataset option (`True` by 82 default) is used to decide whether to enforce a deterministic order. 83 buffer_output_elements: The number of elements each iterator being 84 interleaved should buffer (similar to the `.prefetch()` transformation for 85 each interleaved iterator). 86 prefetch_input_elements: The number of input elements to transform to 87 iterators before they are needed for interleaving. 88 89 Returns: 90 A `Dataset` transformation function, which can be passed to 91 `tf.data.Dataset.apply`. 92 """ 93 94 def _apply_fn(dataset): 95 return readers.ParallelInterleaveDataset(dataset, map_func, cycle_length, 96 block_length, sloppy, 97 buffer_output_elements, 98 prefetch_input_elements) 99 100 return _apply_fn 101 102 103class _DirectedInterleaveDataset(dataset_ops.DatasetV2): 104 """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" 105 106 def __init__(self, selector_input, data_inputs, stop_on_empty_dataset=False): 107 self._selector_input = selector_input 108 self._data_inputs = list(data_inputs) 109 self._stop_on_empty_dataset = stop_on_empty_dataset 110 111 first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0]) 112 first_output_classes = dataset_ops.get_legacy_output_classes(data_inputs[0]) 113 114 for i, data_input in enumerate(data_inputs[1:]): 115 if (dataset_ops.get_legacy_output_types(data_input) != first_output_types 116 or dataset_ops.get_legacy_output_classes(data_input) != 117 first_output_classes): 118 raise TypeError("All datasets must have the same type and class.\n" 119 "dataset 0 vs dataset %s types: %s ; %s\n" 120 "classes: %s ; %s" % 121 (i + 1, first_output_types, 122 dataset_ops.get_legacy_output_types(data_input), 123 first_output_classes, 124 dataset_ops.get_legacy_output_classes(data_input))) 125 126 spec = self._data_inputs[0].element_spec 127 for data_input in self._data_inputs[1:]: 128 spec = nest.pack_sequence_as(spec, [ 129 x.most_specific_compatible_type(y) for (x, y) in zip( 130 nest.flatten(spec), 131 nest.flatten(data_input.element_spec)) 132 ]) 133 self._element_spec = spec 134 135 # pylint: disable=protected-access 136 variant_tensor = ( 137 gen_experimental_dataset_ops.directed_interleave_dataset( 138 self._selector_input._variant_tensor, 139 [data_input._variant_tensor for data_input in self._data_inputs], 140 stop_on_empty_dataset=self._stop_on_empty_dataset, 141 **self._flat_structure)) 142 143 super(_DirectedInterleaveDataset, self).__init__(variant_tensor) 144 145 def _inputs(self): 146 return [self._selector_input] + self._data_inputs 147 148 @property 149 def element_spec(self): 150 return self._element_spec 151 152 153@tf_export("data.experimental.sample_from_datasets", v1=[]) 154def sample_from_datasets_v2(datasets, 155 weights=None, 156 seed=None, 157 stop_on_empty_dataset=False): 158 """Samples elements at random from the datasets in `datasets`. 159 160 Creates a dataset by interleaving elements of `datasets` with `weight[i]` 161 probability of picking an element from dataset `i`. Sampling is done without 162 replacement. For example, suppose we have 2 datasets: 163 164 ```python 165 dataset1 = tf.data.Dataset.range(0, 3) 166 dataset2 = tf.data.Dataset.range(100, 103) 167 ``` 168 169 Suppose also that we sample from these 2 datasets with the following weights: 170 171 ```python 172 sample_dataset = tf.data.experimental.sample_from_datasets( 173 [dataset1, dataset2], weights=[0.5, 0.5]) 174 ``` 175 176 One possible outcome of elements in sample_dataset is: 177 178 ``` 179 print(list(sample_dataset.as_numpy_iterator())) 180 # [100, 0, 1, 101, 2, 102] 181 ``` 182 183 Args: 184 datasets: A non-empty list of `tf.data.Dataset` objects with compatible 185 structure. 186 weights: (Optional.) A list or Tensor of `len(datasets)` floating-point 187 values where `weights[i]` represents the probability to sample from 188 `datasets[i]`, or a `tf.data.Dataset` object where each element is such a 189 list. Defaults to a uniform distribution across `datasets`. 190 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random 191 seed that will be used to create the distribution. See 192 `tf.random.set_seed` for behavior. 193 stop_on_empty_dataset: If `True`, sampling stops if it encounters an empty 194 dataset. If `False`, it skips empty datasets. It is recommended to set it 195 to `True`. Otherwise, the distribution of samples starts off as the user 196 intends, but may change as input datasets become empty. This can be 197 difficult to detect since the dataset starts off looking correct. Default 198 to `False` for backward compatibility. 199 200 Returns: 201 A dataset that interleaves elements from `datasets` at random, according to 202 `weights` if provided, otherwise with uniform probability. 203 204 Raises: 205 TypeError: If the `datasets` or `weights` arguments have the wrong type. 206 ValueError: 207 - If `datasets` is empty, or 208 - If `weights` is specified and does not match the length of `datasets`. 209 """ 210 def _shapes_are_compatible(datasets, weights): 211 if isinstance(weights, ops.Tensor): 212 return weights.shape.is_compatible_with([len(datasets)]) 213 return len(datasets) == len(weights) 214 215 def _skip_datasets_with_zero_weight(datasets, weights): 216 datasets_and_weights = [(dataset, weight) 217 for (dataset, weight) in zip(datasets, weights) 218 if weight > 0] 219 return (zip(*datasets_and_weights) if datasets_and_weights else 220 ([datasets[0].take(0)], [1.])) 221 222 if not datasets: 223 raise ValueError("`datasets` must be a non-empty list of datasets.") 224 225 if not isinstance(weights, dataset_ops.DatasetV2): 226 if weights is None: 227 # Select inputs with uniform probability. 228 logits = [[1.0] * len(datasets)] 229 230 else: 231 if not _shapes_are_compatible(datasets, weights): 232 raise ValueError("`weights` must have the same length as `datasets`.") 233 234 # Use the given `weights` as the probability of choosing the respective 235 # input. 236 if not isinstance(weights, ops.Tensor): 237 datasets, weights = _skip_datasets_with_zero_weight(datasets, weights) 238 weights = ops.convert_to_tensor(weights, name="weights") 239 if weights.dtype not in (dtypes.float32, dtypes.float64): 240 raise TypeError("`weights` must be convertible to a tensor of " 241 "`tf.float32` or `tf.float64` elements.") 242 243 # The `stateless_multinomial()` op expects log-probabilities, as opposed 244 # to weights. 245 logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) 246 247 # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it 248 # is a `Dataset`, it is possible that evaluating it has a side effect the 249 # user depends on. 250 if len(datasets) == 1: 251 return datasets[0] 252 253 def select_dataset_constant_logits(seed): 254 return array_ops.squeeze( 255 gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed), 256 axis=[0, 1]) 257 258 selector_input = dataset_ops.MapDataset( 259 random_ops.RandomDataset(seed).batch(2), 260 select_dataset_constant_logits, 261 use_inter_op_parallelism=False) 262 263 else: 264 # Use each element of the given `weights` dataset as the probability of 265 # choosing the respective input. 266 # 267 # The `stateless_multinomial()` op expects log-probabilities, as opposed to 268 # weights. 269 logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) 270 271 def select_dataset_varying_logits(logits, seed): 272 return array_ops.squeeze( 273 gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed), 274 axis=[0, 1]) 275 276 logits_and_seeds = dataset_ops.Dataset.zip( 277 (logits_ds, random_ops.RandomDataset(seed).batch(2))) 278 selector_input = dataset_ops.MapDataset( 279 logits_and_seeds, 280 select_dataset_varying_logits, 281 use_inter_op_parallelism=False) 282 283 return _DirectedInterleaveDataset(selector_input, datasets, 284 stop_on_empty_dataset) 285 286 287@tf_export(v1=["data.experimental.sample_from_datasets"]) 288def sample_from_datasets_v1(datasets, 289 weights=None, 290 seed=None, 291 stop_on_empty_dataset=False): 292 return dataset_ops.DatasetV1Adapter( 293 sample_from_datasets_v2(datasets, weights, seed, stop_on_empty_dataset)) 294 295 296sample_from_datasets_v1.__doc__ = sample_from_datasets_v2.__doc__ 297 298 299@tf_export("data.experimental.choose_from_datasets", v1=[]) 300def choose_from_datasets_v2(datasets, 301 choice_dataset, 302 stop_on_empty_dataset=False): 303 """Creates a dataset that deterministically chooses elements from `datasets`. 304 305 For example, given the following datasets: 306 307 ```python 308 datasets = [tf.data.Dataset.from_tensors("foo").repeat(), 309 tf.data.Dataset.from_tensors("bar").repeat(), 310 tf.data.Dataset.from_tensors("baz").repeat()] 311 312 # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. 313 choice_dataset = tf.data.Dataset.range(3).repeat(3) 314 315 result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset) 316 ``` 317 318 The elements of `result` will be: 319 320 ``` 321 "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" 322 ``` 323 324 Args: 325 datasets: A non-empty list of `tf.data.Dataset` objects with compatible 326 structure. 327 choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between `0` 328 and `len(datasets) - 1`. 329 stop_on_empty_dataset: If `True`, selection stops if it encounters an empty 330 dataset. If `False`, it skips empty datasets. It is recommended to set it 331 to `True`. Otherwise, the selected elements start off as the user intends, 332 but may change as input datasets become empty. This can be difficult to 333 detect since the dataset starts off looking correct. Default to `False` 334 for backward compatibility. 335 336 Returns: 337 A dataset that interleaves elements from `datasets` according to the values 338 of `choice_dataset`. 339 340 Raises: 341 TypeError: If `datasets` or `choice_dataset` has the wrong type. 342 ValueError: If `datasets` is empty. 343 """ 344 if not datasets: 345 raise ValueError("`datasets` must be a non-empty list of datasets.") 346 if choice_dataset is None or not structure.are_compatible( 347 choice_dataset.element_spec, tensor_spec.TensorSpec([], dtypes.int64)): 348 raise TypeError("`choice_dataset` must be a dataset of scalar " 349 "`tf.int64` tensors.") 350 return _DirectedInterleaveDataset(choice_dataset, datasets, 351 stop_on_empty_dataset) 352 353 354@tf_export(v1=["data.experimental.choose_from_datasets"]) 355def choose_from_datasets_v1(datasets, 356 choice_dataset, 357 stop_on_empty_dataset=False): 358 return dataset_ops.DatasetV1Adapter( 359 choose_from_datasets_v2(datasets, choice_dataset, stop_on_empty_dataset)) 360 361 362choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__ 363 364if tf2.enabled(): 365 choose_from_datasets = choose_from_datasets_v2 366 sample_from_datasets = sample_from_datasets_v2 367else: 368 choose_from_datasets = choose_from_datasets_v1 369 sample_from_datasets = sample_from_datasets_v1 370