# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Batching dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.dense_to_ragged_batch")
def dense_to_ragged_batch(batch_size,
                          drop_remainder=False,
                          row_splits_dtype=dtypes.int64):
  """A transformation that batches ragged elements into `tf.RaggedTensor`s.

  This transformation combines multiple consecutive elements of the input
  dataset into a single element.

  Like `tf.data.Dataset.batch`, the components of the resulting element will
  have an additional outer dimension, which will be `batch_size` (or
  `N % batch_size` for the last element if `batch_size` does not divide the
  number of input elements `N` evenly and `drop_remainder` is `False`). If
  your program depends on the batches having the same outer dimension, you
  should set the `drop_remainder` argument to `True` to prevent the smaller
  batch from being produced.

  Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
  different shapes:

  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
     fully defined, then it is batched as normal.
  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains
     one or more axes with unknown size (i.e., `shape[i]=None`), then the output
     will contain a `tf.RaggedTensor` that is ragged up to any of such
     dimensions.
  *  If an input element is a `tf.RaggedTensor` or any other type, then it is
     batched as normal.

  Example:

  >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
  >>> dataset = dataset.map(lambda x: tf.range(x))
  >>> dataset.element_spec.shape
  TensorShape([None])
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
  >>> for batch in dataset:
  ...   print(batch)
  <tf.RaggedTensor [[], [0]]>
  <tf.RaggedTensor [[0, 1], [0, 1, 2]]>
  <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.
    row_splits_dtype: The dtype that should be used for the `row_splits` of any
      new ragged tensors.  Existing `tf.RaggedTensor` elements do not have their
      row_splits dtype changed.

  Returns:
    Dataset: A `Dataset`.
  """

  def _apply_fn(dataset):
    # First re-encode any variable-shape dense tensors as ragged, then a plain
    # BatchDataset can batch the (now uniformly-typed) elements.
    ragged_dataset = _DenseToRaggedDataset(dataset, row_splits_dtype)
    return dataset_ops.BatchDataset(
        ragged_dataset, batch_size=batch_size, drop_remainder=drop_remainder)

  return _apply_fn


@tf_export("data.experimental.dense_to_sparse_batch")
def dense_to_sparse_batch(batch_size, row_shape):
  """A transformation that batches ragged elements into `tf.sparse.SparseTensor`s.

  Like `Dataset.padded_batch()`, this transformation combines multiple
  consecutive elements of the dataset, which might have different
  shapes, into a single element. The resulting element has three
  components (`indices`, `values`, and `dense_shape`), which
  comprise a `tf.sparse.SparseTensor` that represents the same data. The
  `row_shape` represents the dense shape of each row in the
  resulting `tf.sparse.SparseTensor`, to which the effective batch size is
  prepended. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[6])) ==
  {
      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
       ['a', 'b', 'c', 'a', 'b'],                 # values
       [2, 6]),                                   # dense_shape
      ([[0, 0], [0, 1], [0, 2], [0, 3]],
       ['a', 'b', 'c', 'd'],
       [1, 6])
  }
  ```

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
      representing the equivalent dense shape of a row in the resulting
      `tf.sparse.SparseTensor`. Each element of this dataset must have the same
      rank as `row_shape`, and must have size less than or equal to `row_shape`
      in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()`")
@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"])
def map_and_batch_with_legacy_function(map_func,
                                       batch_size,
                                       num_parallel_batches=None,
                                       drop_remainder=False,
                                       num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  NOTE: This is an escape hatch for existing uses of `map_and_batch` that do not
  work with V2 functions. New uses are strongly discouraged and existing uses
  should migrate to `map_and_batch` as this method will be removed in V2.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  # Resolve the two mutually-exclusive parallelism knobs into a single
  # `num_parallel_calls` value: neither given -> one batch's worth of calls;
  # only `num_parallel_batches` given -> that many whole batches in flight.
  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
                     "arguments are mutually exclusive.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder,
                               use_legacy_function=True)

  return _apply_fn


@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
    "optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
                  batch_size,
                  num_parallel_batches=None,
                  drop_remainder=False,
                  num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  Maps `map_func` across `batch_size` consecutive elements of this dataset
  and then combines them into a batch. Functionally, it is equivalent to `map`
  followed by `batch`. This API is temporary and deprecated since input pipeline
  optimization now fuses consecutive `map` and `batch` operations automatically.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  # Same resolution logic as `map_and_batch_with_legacy_function`: collapse the
  # two mutually-exclusive parallelism knobs into `num_parallel_calls`.
  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
                     "arguments are mutually exclusive.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
@tf_export("data.experimental.unbatch")
def unbatch():
  """Splits elements of a dataset into multiple elements on the batch dimension.

  For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
  where `B` may vary for each input element, then for each element in the
  dataset, the unbatched dataset will contain `B` consecutive elements
  of shape `[a0, a1, ...]`.

  ```python
  # NOTE: The following example uses `{ ... }` to represent the contents
  # of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.unbatch() == {
      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # The fused implementation lives on `Dataset` itself; this wrapper only
    # exists for the deprecated `tf.data.experimental.unbatch` endpoint.
    return dataset.unbatch()

  return _apply_fn


class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s."""

  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details."""
    # The kernel only supports single-component (non-nested) elements, which
    # manifest here as a bare `DType` rather than a structure of them.
    if not isinstance(
        dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType):
      raise TypeError("DenseToSparseDataset requires an input whose elements "
                      "have a single component, whereas the input has %r." %
                      dataset_ops.get_legacy_output_types(input_dataset))
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    # Output is a SparseTensor of shape [None] + row_shape: an unknown batch
    # dimension prepended to the caller-supplied dense row shape.
    self._element_spec = sparse_tensor.SparseTensorSpec(
        tensor_shape.TensorShape([None]).concatenate(self._row_shape),
        dataset_ops.get_legacy_output_types(input_dataset))

    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **self._flat_structure)
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _MapAndBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over a batch of elements."""

  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    self._input_dataset = input_dataset

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
    # pylint: disable=protected-access
    if constant_drop_remainder:
      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
      # or `False` (explicitly retaining the remainder). Only when it is
      # statically `True` can the batch dimension be fixed to `batch_size`.
      # pylint: disable=g-long-lambda
      self._element_spec = nest.map_structure(
          lambda component_spec: component_spec._batch(
              tensor_util.constant_value(self._batch_size_t)),
          self._map_func.output_structure)
    else:
      self._element_spec = nest.map_structure(
          lambda component_spec: component_spec._batch(None),
          self._map_func.output_structure)
    # pylint: enable=protected-access
    variant_tensor = ged_ops.map_and_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **self._flat_structure)
    super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._element_spec


class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that encodes dense inputs as ragged (w/ ragged_rank=0).

  In particular:

  * Any tf.Tensor elements with rank>0 are encoded as ragged tensors with
    ragged_rank=0.  This allows tensors with varying shape to be batched
    together.
  * Any other elements are left as-is.
  """

  def __init__(self, input_dataset, row_splits_dtype):
    """Constructs a new _DenseToRaggedDataset.

    Args:
      input_dataset: The dataset whose tf.Tensor elements should be made
        ragged.
      row_splits_dtype: The dtype that should be used for the `row_splits` of
        any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
        have their row_splits dtype changed.
    """
    # Replace each TensorSpec in the input dataset's structure with a
    # corresponding RaggedTensorSpec.
    def to_ragged_spec(spec):
      """Returns the new spec based on RaggedTensors."""
      # Fully-defined tensors, unknown-rank tensors, and non-Tensor elements
      # (including existing RaggedTensors) pass through unchanged.
      if (not isinstance(spec, tensor_spec.TensorSpec) or
          spec.shape.rank is None or
          spec.shape.is_fully_defined()):
        return spec
      else:
        # Ragged up to (and including) the outermost axis with unknown size.
        ragged_rank = max([
            axis for (axis, size) in enumerate(spec.shape.as_list())
            if size is None
        ])
        return ragged_tensor.RaggedTensorSpec(
            shape=spec.shape,
            dtype=spec.dtype,
            ragged_rank=ragged_rank,
            row_splits_dtype=row_splits_dtype)

    self._structure = nest.map_structure(to_ragged_spec,
                                         input_dataset.element_spec)

    # Replace each tf.Tensor value in the input dataset with a variant-encoded
    # RaggedTensor. Since we're updating the corresponding structure to be
    # a RaggedTensorSpec, this variant-encoded tensor will be decoded with
    # RaggedTensorSpec._from_tensor_list.
    def to_ragged_variant(value):
      """Re-encode Tensors as RaggedTensors."""
      # Mirrors the pass-through conditions of `to_ragged_spec` above.
      if (not isinstance(value, ops.Tensor) or
          value.shape.rank is None or
          value.shape.is_fully_defined()):
        return value
      else:
        spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
        if spec._ragged_rank > 0:  # pylint: disable=protected-access
          value = ragged_tensor.RaggedTensor.from_tensor(
              value, ragged_rank=spec._ragged_rank)  # pylint: disable=protected-access
        return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access

    # Tuples are automatically unpacked by `dataset.map` so we repack them.
    if dataset_ops._should_unpack_args(input_dataset.element_spec):  # pylint: disable=protected-access
      map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
    else:
      map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

    self._mapped_dataset = input_dataset.map(map_fn)

    variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
    super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)

  @property
  def element_spec(self):
    return self._structure