# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Non-deterministic dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation


class ParallelInterleaveDataset(dataset_ops.Dataset):
  """A `Dataset` that maps a function over its input and flattens the result.

  Implements `tf.contrib.data.parallel_interleave()`: elements of
  `cycle_length` nested datasets (produced by applying `map_func` to the
  input elements) are interleaved, optionally in a non-deterministic
  ("sloppy") order for throughput.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.contrib.data.parallel_interleave()` for details."""
    super(ParallelInterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    # Wrap `map_func` in a `Defun` whose argument signature is the flattened,
    # dense representation of `input_dataset`'s element types.
    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset, since `Defun`
      # arguments otherwise carry no static shapes.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      # Restore the caller-visible nested structure (including any
      # SparseTensors, which cross the Defun boundary serialized).
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      # Side effect of tracing: record the structure of the nested dataset so
      # the properties below can report it as this dataset's element structure.
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    # NOTE(review): this presumably forces instantiation/tracing of the Defun
    # (and hence the assignment of `self._output_*` above) before the
    # properties are first read — confirm against `function.Defun` semantics.
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
    # The two buffering knobs are optional; when `None`, defaults derived
    # from `block_length` / `cycle_length` are substituted.
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)

  def _as_variant_tensor(self):
    # Emits the underlying `ParallelInterleaveDataset` op, flattening the
    # (possibly sparse) element structure into dense types/shapes as the
    # kernel requires.
    return gen_dataset_ops.parallel_interleave_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._sloppy,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    # Set during tracing of `tf_map_func` in `__init__`.
    return self._output_classes

  @property
  def output_shapes(self):
    # Set during tracing of `tf_map_func` in `__init__`.
    return self._output_shapes

  @property
  def output_types(self):
    # Set during tracing of `tf_map_func` in `__init__`.
    return self._output_types


def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` maps `map_func` across its input to produce nested
  datasets, and outputs their elements interleaved. Unlike
  @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested
  datasets in parallel, which increases the throughput, especially in the
  presence of stragglers. Furthermore, the `sloppy` argument can be used to
  improve performance, by relaxing the requirement that the outputs are produced
  in a deterministic order, and allowing the implementation to skip over nested
  datasets whose elements are not readily available when requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.contrib.data.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: If false, elements are produced in deterministic order. Otherwise,
      the implementation is allowed, for the sake of expediency, to produce
      elements in a non-deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation for
      each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    return ParallelInterleaveDataset(
        dataset, map_func, cycle_length, block_length, sloppy,
        buffer_output_elements, prefetch_input_elements)

  return _apply_fn


@deprecation.deprecated(
    None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.")
def sloppy_interleave(map_func, cycle_length, block_length=1):
  """A non-deterministic version of the `Dataset.interleave()` transformation.

  `sloppy_interleave()` maps `map_func` across `dataset`, and
  non-deterministically interleaves the results.

  The resulting dataset is almost identical to `interleave`. The key
  difference is that if retrieving a value from a given output iterator would
  cause `get_next` to block, that iterator will be skipped, and consumed
  when next available. If consuming from all iterators would cause the
  `get_next` call to block, the `get_next` call blocks until the first value is
  available.

  If the underlying datasets produce elements as fast as they are consumed, the
  `sloppy_interleave` transformation behaves identically to `interleave`.
  However, if an underlying dataset would block the consumer,
  `sloppy_interleave` can violate the round-robin order (that `interleave`
  strictly obeys), producing an element from a different underlying
  dataset instead.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.contrib.data.sloppy_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: The order of elements in the resulting dataset is not
  deterministic. Use `Dataset.interleave()` if you want the elements to have a
  deterministic order.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to a
      `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`. Note:
      `sloppy_interleave` will skip the remainder of elements in the
      `block_length` in order to avoid blocking.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    # Thin wrapper: `sloppy_interleave` is `parallel_interleave` with
    # `sloppy=True` and default buffering.
    return ParallelInterleaveDataset(
        dataset,
        map_func,
        cycle_length,
        block_length,
        sloppy=True,
        buffer_output_elements=None,
        prefetch_input_elements=None)

  return _apply_fn