1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Batching dataset transformations.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.data.experimental.ops import get_single_element 23from tensorflow.python.data.experimental.ops import grouping 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.data.util import convert 26from tensorflow.python.data.util import nest 27from tensorflow.python.data.util import structure 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import errors 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import tensor_util 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import check_ops 37from tensorflow.python.ops import control_flow_ops 38from tensorflow.python.ops import gen_array_ops 39from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 40from tensorflow.python.ops import math_ops 41from tensorflow.python.ops import sparse_ops 42from tensorflow.python.util import deprecation 43from tensorflow.python.util.tf_export import tf_export 44 45 46def batch_window(dataset): 47 """Batches a window of tensors. 48 49 Args: 50 dataset: the input dataset. 51 52 Returns: 53 A `Tensor` representing the batch of the entire input dataset. 54 """ 55 dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset) 56 if isinstance(dataset_output_classes, tuple): 57 raise TypeError("Input dataset expected to have a single component") 58 if dataset_output_classes is ops.Tensor: 59 return _batch_dense_window(dataset) 60 elif dataset_output_classes is sparse_tensor.SparseTensor: 61 return _batch_sparse_window(dataset) 62 else: 63 raise TypeError("Unsupported dataset type: %s" % dataset_output_classes) 64 65 66def _batch_dense_window(dataset): 67 """Batches a window of dense tensors.""" 68 69 def key_fn(_): 70 return np.int64(0) 71 72 def shape_init_fn(_): 73 return array_ops.shape(first_element) 74 75 def shape_reduce_fn(state, value): 76 check_ops.assert_equal(state, array_ops.shape(value)) 77 return state 78 79 def finalize_fn(state): 80 return state 81 82 dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) 83 if dataset_output_shapes.is_fully_defined(): 84 shape = dataset_output_shapes 85 else: 86 first_element = get_single_element.get_single_element(dataset.take(1)) 87 shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, 88 finalize_fn) 89 shape = get_single_element.get_single_element( 90 dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) 91 92 def batch_init_fn(_): 93 batch_shape = array_ops.concat([[0], shape], 0) 94 return gen_array_ops.empty( 95 batch_shape, dtype=dataset_ops.get_legacy_output_types(dataset)) 96 97 def batch_reduce_fn(state, value): 98 return array_ops.concat([state, [value]], 0) 99 100 batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) 101 return get_single_element.get_single_element( 102 dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer))) 103 104 105def _batch_sparse_window(dataset): 106 """Batches a window of sparse tensors.""" 107 108 def key_fn(_): 109 return np.int64(0) 110 111 def shape_init_fn(_): 112 return first_element.dense_shape 113 114 def shape_reduce_fn(state, value): 115 check_ops.assert_equal(state, value.dense_shape) 116 return state 117 118 def finalize_fn(state): 119 return state 120 121 dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) 122 if dataset_output_shapes.is_fully_defined(): 123 shape = dataset_output_shapes 124 else: 125 first_element = get_single_element.get_single_element(dataset.take(1)) 126 shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, 127 finalize_fn) 128 shape = get_single_element.get_single_element( 129 dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) 130 131 def batch_init_fn(_): 132 indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0) 133 return sparse_tensor.SparseTensor( 134 indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), 135 values=constant_op.constant( 136 [], shape=[0], dtype=dataset_ops.get_legacy_output_types(dataset)), 137 dense_shape=array_ops.concat( 138 [np.array([0], dtype=np.int64), 139 math_ops.cast(shape, dtypes.int64)], 0)) 140 141 def batch_reduce_fn(state, value): 142 return sparse_ops.sparse_concat(0, [state, value]) 143 144 def reshape_fn(value): 145 return sparse_ops.sparse_reshape( 146 value, 147 array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0)) 148 149 batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) 150 return get_single_element.get_single_element( 151 dataset.map(reshape_fn).apply( 152 grouping.group_by_reducer(key_fn, batch_reducer))) 153 154 155@tf_export("data.experimental.dense_to_sparse_batch") 156def dense_to_sparse_batch(batch_size, row_shape): 157 """A transformation that batches ragged elements into `tf.SparseTensor`s. 158 159 Like `Dataset.padded_batch()`, this transformation combines multiple 160 consecutive elements of the dataset, which might have different 161 shapes, into a single element. The resulting element has three 162 components (`indices`, `values`, and `dense_shape`), which 163 comprise a `tf.SparseTensor` that represents the same data. The 164 `row_shape` represents the dense shape of each row in the 165 resulting `tf.SparseTensor`, to which the effective batch size is 166 prepended. For example: 167 168 ```python 169 # NOTE: The following examples use `{ ... }` to represent the 170 # contents of a dataset. 171 a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } 172 173 a.apply(tf.data.experimental.dense_to_sparse_batch( 174 batch_size=2, row_shape=[6])) == 175 { 176 ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices 177 ['a', 'b', 'c', 'a', 'b'], # values 178 [2, 6]), # dense_shape 179 ([[0, 0], [0, 1], [0, 2], [0, 3]], 180 ['a', 'b', 'c', 'd'], 181 [1, 6]) 182 } 183 ``` 184 185 Args: 186 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the 187 number of consecutive elements of this dataset to combine in a 188 single batch. 189 row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like 190 object representing the equivalent dense shape of a row in the 191 resulting `tf.SparseTensor`. Each element of this dataset must 192 have the same rank as `row_shape`, and must have size less 193 than or equal to `row_shape` in each dimension. 194 195 Returns: 196 A `Dataset` transformation function, which can be passed to 197 `tf.data.Dataset.apply`. 198 """ 199 200 def _apply_fn(dataset): 201 return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) 202 203 return _apply_fn 204 205 206def padded_batch_window(dataset, padded_shape, padding_value=None): 207 """Batches a window of tensors with padding. 208 209 Args: 210 dataset: the input dataset. 211 padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like 212 object representing the shape to which the input elements should be padded 213 prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a 214 `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the 215 maximum size of that dimension in each batch. 216 padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the 217 padding value to use. Defaults are `0` for numeric types and the empty 218 string for string types. If `dataset` contains `tf.SparseTensor`, this 219 value is ignored. 220 221 Returns: 222 A `Tensor` representing the batch of the entire input dataset. 223 224 Raises: 225 ValueError: if invalid arguments are provided. 226 """ 227 dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset) 228 if not issubclass(dataset_output_classes, 229 (ops.Tensor, sparse_tensor.SparseTensor)): 230 raise TypeError("Input dataset expected to have a single tensor component") 231 if issubclass(dataset_output_classes, (ops.Tensor)): 232 return _padded_batch_dense_window(dataset, padded_shape, padding_value) 233 elif issubclass(dataset_output_classes, (sparse_tensor.SparseTensor)): 234 if padding_value is not None: 235 raise ValueError("Padding value not allowed for sparse tensors") 236 return _padded_batch_sparse_window(dataset, padded_shape) 237 else: 238 raise TypeError("Unsupported dataset type: %s" % dataset_output_classes) 239 240 241def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): 242 """Batches a window of dense tensors with padding.""" 243 244 padded_shape = math_ops.cast( 245 convert.partial_shape_to_tensor(padded_shape), dtypes.int32) 246 247 def key_fn(_): 248 return np.int64(0) 249 250 def max_init_fn(_): 251 return padded_shape 252 253 def max_reduce_fn(state, value): 254 """Computes the maximum shape to pad to.""" 255 condition = math_ops.reduce_all( 256 math_ops.logical_or( 257 math_ops.less_equal(array_ops.shape(value), padded_shape), 258 math_ops.equal(padded_shape, -1))) 259 assert_op = control_flow_ops.Assert(condition, [ 260 "Actual shape greater than padded shape: ", 261 array_ops.shape(value), padded_shape 262 ]) 263 with ops.control_dependencies([assert_op]): 264 return math_ops.maximum(state, array_ops.shape(value)) 265 266 def finalize_fn(state): 267 return state 268 269 # Compute the padded shape. 270 max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) 271 padded_shape = get_single_element.get_single_element( 272 dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) 273 274 dataset_output_types = dataset_ops.get_legacy_output_types(dataset) 275 if padding_value is None: 276 if dataset_output_types == dtypes.string: 277 padding_value = "" 278 elif dataset_output_types == dtypes.bool: 279 padding_value = False 280 elif dataset_output_types == dtypes.variant: 281 raise TypeError("Unable to create padding for field of type 'variant'") 282 else: 283 padding_value = 0 284 285 def batch_init_fn(_): 286 batch_shape = array_ops.concat( 287 [np.array([0], dtype=np.int32), padded_shape], 0) 288 return gen_array_ops.empty(batch_shape, dtype=dataset_output_types) 289 290 def batch_reduce_fn(state, value): 291 return array_ops.concat([state, [value]], 0) 292 293 def pad_fn(value): 294 shape = array_ops.shape(value) 295 left = array_ops.zeros_like(shape) 296 right = padded_shape - shape 297 return array_ops.pad( 298 value, array_ops.stack([left, right], 1), constant_values=padding_value) 299 300 batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) 301 return get_single_element.get_single_element( 302 dataset.map(pad_fn).apply( 303 grouping.group_by_reducer(key_fn, batch_reducer))) 304 305 306def _padded_batch_sparse_window(dataset, padded_shape): 307 """Batches a window of sparse tensors with padding.""" 308 309 def key_fn(_): 310 return np.int64(0) 311 312 def max_init_fn(_): 313 return convert.partial_shape_to_tensor(padded_shape) 314 315 def max_reduce_fn(state, value): 316 """Computes the maximum shape to pad to.""" 317 condition = math_ops.reduce_all( 318 math_ops.logical_or( 319 math_ops.less_equal(value.dense_shape, padded_shape), 320 math_ops.equal(padded_shape, -1))) 321 assert_op = control_flow_ops.Assert(condition, [ 322 "Actual shape greater than padded shape: ", value.dense_shape, 323 padded_shape 324 ]) 325 with ops.control_dependencies([assert_op]): 326 return math_ops.maximum(state, value.dense_shape) 327 328 def finalize_fn(state): 329 return state 330 331 # Compute the padded shape. 332 max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) 333 padded_shape = get_single_element.get_single_element( 334 dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) 335 336 def batch_init_fn(_): 337 indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]], 338 0) 339 return sparse_tensor.SparseTensor( 340 indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), 341 values=constant_op.constant( 342 [], shape=[0], dtype=dataset_ops.get_legacy_output_types(dataset)), 343 dense_shape=array_ops.concat( 344 [np.array([0], dtype=np.int64), padded_shape], 0)) 345 346 def batch_reduce_fn(state, value): 347 padded_value = sparse_tensor.SparseTensor( 348 indices=value.indices, values=value.values, dense_shape=padded_shape) 349 reshaped_value = sparse_ops.sparse_reshape( 350 padded_value, 351 array_ops.concat( 352 [np.array([1], dtype=np.int64), padded_value.dense_shape], 0)) 353 return sparse_ops.sparse_concat(0, [state, reshaped_value]) 354 355 reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) 356 return get_single_element.get_single_element( 357 dataset.apply(grouping.group_by_reducer(key_fn, reducer))) 358 359 360class _UnbatchDataset(dataset_ops.UnaryDataset): 361 """A dataset that splits the elements of its input into multiple elements.""" 362 363 def __init__(self, input_dataset): 364 """See `unbatch()` for more details.""" 365 input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset) 366 flat_shapes = nest.flatten(input_shapes) 367 if any(s.ndims == 0 for s in flat_shapes): 368 raise ValueError("Cannot unbatch an input with scalar components.") 369 known_batch_dim = tensor_shape.Dimension(None) 370 for s in flat_shapes: 371 try: 372 known_batch_dim = known_batch_dim.merge_with(s[0]) 373 except ValueError: 374 raise ValueError("Cannot unbatch an input whose components have " 375 "different batch sizes.") 376 self._input_dataset = input_dataset 377 378 self._structure = structure.convert_legacy_structure( 379 dataset_ops.get_legacy_output_types(input_dataset), 380 nest.map_structure(lambda s: s[1:], input_shapes), 381 dataset_ops.get_legacy_output_classes(input_dataset)) 382 383 variant_tensor = ged_ops.experimental_unbatch_dataset( 384 self._input_dataset._variant_tensor, # pylint: disable=protected-access 385 **dataset_ops.flat_structure(self)) 386 super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor) 387 388 @property 389 def _element_structure(self): 390 return self._structure 391 392 393@tf_export("data.experimental.unbatch") 394def unbatch(): 395 """Splits elements of a dataset into multiple elements on the batch dimension. 396 397 For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, 398 where `B` may vary for each input element, then for each element in the 399 dataset, the unbatched dataset will contain `B` consecutive elements 400 of shape `[a0, a1, ...]`. 401 402 ```python 403 # NOTE: The following example uses `{ ... }` to represent the contents 404 # of a dataset. 405 a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } 406 407 a.apply(tf.data.experimental.unbatch()) == { 408 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'} 409 ``` 410 411 Returns: 412 A `Dataset` transformation function, which can be passed to 413 `tf.data.Dataset.apply`. 414 """ 415 416 def _apply_fn(dataset): 417 """Function from `Dataset` to `Dataset` that applies the transformation.""" 418 # NOTE(mrry): We must ensure that any SparseTensors in `dataset` 419 # are normalized to the rank-1 dense representation, so that the 420 # sparse-oblivious unbatching logic will slice them 421 # appropriately. This leads to a somewhat inefficient re-encoding step 422 # for all SparseTensor components. 423 # TODO(mrry): Consider optimizing this in future if it turns out to be 424 # a bottleneck. 425 def normalize(arg, *rest): 426 # pylint: disable=protected-access 427 if rest: 428 return dataset._element_structure._to_batched_tensor_list((arg,) + rest) 429 else: 430 return dataset._element_structure._to_batched_tensor_list(arg) 431 432 normalized_dataset = dataset.map(normalize) 433 434 # NOTE(mrry): Our `map()` has lost information about the sparseness 435 # of any SparseTensor components, so re-apply the structure of the 436 # original dataset. 437 restructured_dataset = _RestructuredDataset( 438 normalized_dataset, 439 dataset_ops.get_legacy_output_types(dataset), 440 dataset_ops.get_legacy_output_shapes(dataset), 441 dataset_ops.get_legacy_output_classes(dataset), 442 allow_unsafe_cast=True) 443 return _UnbatchDataset(restructured_dataset) 444 445 return _apply_fn 446 447 448class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset): 449 """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" 450 451 def __init__(self, input_dataset, batch_size, row_shape): 452 """See `Dataset.dense_to_sparse_batch()` for more details.""" 453 if not isinstance( 454 dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType): 455 raise TypeError("DenseToSparseDataset requires an input whose elements " 456 "have a single component, whereas the input has %r." % 457 dataset_ops.get_legacy_output_types(input_dataset)) 458 self._input_dataset = input_dataset 459 self._batch_size = batch_size 460 self._row_shape = row_shape 461 self._structure = structure.SparseTensorStructure( 462 dataset_ops.get_legacy_output_types(input_dataset), 463 tensor_shape.vector(None).concatenate(self._row_shape)) 464 465 variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset( 466 self._input_dataset._variant_tensor, # pylint: disable=protected-access 467 self._batch_size, 468 row_shape=convert.partial_shape_to_tensor(self._row_shape), 469 **dataset_ops.flat_structure(self)) 470 super(_DenseToSparseBatchDataset, self).__init__(input_dataset, 471 variant_tensor) 472 473 @property 474 def _element_structure(self): 475 return self._structure 476 477 478class _RestructuredDataset(dataset_ops.UnaryDataset): 479 """An internal helper for changing the structure and shape of a dataset.""" 480 481 def __init__(self, 482 dataset, 483 output_types, 484 output_shapes=None, 485 output_classes=None, 486 allow_unsafe_cast=False): 487 """Creates a new dataset with the given output types and shapes. 488 489 The given `dataset` must have a structure that is convertible: 490 * `dataset.output_types` must be the same as `output_types` module nesting. 491 * Each shape in `dataset.output_shapes` must be compatible with each shape 492 in `output_shapes` (if given). 493 494 Note: This helper permits "unsafe casts" for shapes, equivalent to using 495 `tf.Tensor.set_shape()` where domain-specific knowledge is available. 496 497 Args: 498 dataset: A `Dataset` object. 499 output_types: A nested structure of `tf.DType` objects. 500 output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. 501 If omitted, the shapes will be inherited from `dataset`. 502 output_classes: (Optional.) A nested structure of class types. 503 If omitted, the class types will be inherited from `dataset`. 504 allow_unsafe_cast: (Optional.) If `True`, the caller may switch the 505 reported output types and shapes of the restructured dataset, e.g. to 506 switch a sparse tensor represented as `tf.variant` to its user-visible 507 type and shape. 508 509 Raises: 510 ValueError: If either `output_types` or `output_shapes` is not compatible 511 with the structure of `dataset`. 512 """ 513 self._input_dataset = dataset 514 515 input_types = dataset_ops.get_legacy_output_types(dataset) 516 if not allow_unsafe_cast: 517 # Validate that the types are compatible. 518 output_types = nest.map_structure(dtypes.as_dtype, output_types) 519 flat_original_types = nest.flatten(input_types) 520 flat_new_types = nest.flatten(output_types) 521 if flat_original_types != flat_new_types: 522 raise ValueError( 523 "Dataset with output types %r cannot be restructured to have " 524 "output types %r" % 525 (dataset_ops.get_legacy_output_types(dataset), output_types)) 526 527 input_shapes = dataset_ops.get_legacy_output_shapes(dataset) 528 if output_shapes is None: 529 # Inherit shapes from the original `dataset`. 530 output_shapes = nest.pack_sequence_as( 531 output_types, nest.flatten(input_shapes)) 532 else: 533 if not allow_unsafe_cast: 534 # Validate that the shapes are compatible. 535 nest.assert_same_structure(output_types, output_shapes) 536 flat_original_shapes = nest.flatten(input_shapes) 537 flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) 538 539 for original_shape, new_shape in zip(flat_original_shapes, 540 flat_new_shapes): 541 if not original_shape.is_compatible_with(new_shape): 542 raise ValueError( 543 "Dataset with output shapes %r cannot be restructured to have " 544 "incompatible output shapes %r" % (input_shapes, 545 output_shapes)) 546 output_shapes = nest.map_structure_up_to( 547 output_types, tensor_shape.as_shape, output_shapes) 548 549 input_classes = dataset_ops.get_legacy_output_classes(dataset) 550 if output_classes is None: 551 # Inherit class types from the original `dataset`. 552 output_classes = nest.pack_sequence_as( 553 output_types, nest.flatten(input_classes)) 554 555 self._structure = structure.convert_legacy_structure( 556 output_types, output_shapes, output_classes) 557 variant_tensor = self._input_dataset._variant_tensor # pylint: disable=protected-access 558 super(_RestructuredDataset, self).__init__(dataset, variant_tensor) 559 560 @property 561 def _element_structure(self): 562 return self._structure 563 564 565class _MapAndBatchDataset(dataset_ops.UnaryDataset): 566 """A `Dataset` that maps a function over a batch of elements.""" 567 568 def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, 569 drop_remainder, use_legacy_function=False): 570 """See `Dataset.map()` for details.""" 571 self._input_dataset = input_dataset 572 573 self._map_func = dataset_ops.StructuredFunctionWrapper( 574 map_func, 575 "tf.data.experimental.map_and_batch()", 576 dataset=input_dataset, 577 use_legacy_function=use_legacy_function) 578 self._batch_size_t = ops.convert_to_tensor( 579 batch_size, dtype=dtypes.int64, name="batch_size") 580 self._num_parallel_calls_t = ops.convert_to_tensor( 581 num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") 582 self._drop_remainder_t = ops.convert_to_tensor( 583 drop_remainder, dtype=dtypes.bool, name="drop_remainder") 584 585 constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t) 586 if constant_drop_remainder: 587 # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically) 588 # or `False` (explicitly retaining the remainder). 589 self._structure = self._map_func.output_structure._batch( # pylint: disable=protected-access 590 tensor_util.constant_value(self._batch_size_t)) 591 else: 592 self._structure = self._map_func.output_structure._batch(None) # pylint: disable=protected-access 593 variant_tensor = ged_ops.experimental_map_and_batch_dataset( 594 self._input_dataset._variant_tensor, # pylint: disable=protected-access 595 self._map_func.function.captured_inputs, 596 f=self._map_func.function, 597 batch_size=self._batch_size_t, 598 num_parallel_calls=self._num_parallel_calls_t, 599 drop_remainder=self._drop_remainder_t, 600 preserve_cardinality=True, 601 **dataset_ops.flat_structure(self)) 602 super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor) 603 604 def _functions(self): 605 return [self._map_func] 606 607 @property 608 def _element_structure(self): 609 return self._structure 610 611 612@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()") 613@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"]) 614def map_and_batch_with_legacy_function(map_func, 615 batch_size, 616 num_parallel_batches=None, 617 drop_remainder=False, 618 num_parallel_calls=None): 619 """Fused implementation of `map` and `batch`. 620 621 NOTE: This is an escape hatch for existing uses of `map_and_batch` that do not 622 work with V2 functions. New uses are strongly discouraged and existing uses 623 should migrate to `map_and_batch` as this method will not be removed in V2. 624 625 Args: 626 map_func: A function mapping a nested structure of tensors to another 627 nested structure of tensors. 628 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 629 consecutive elements of this dataset to combine in a single batch. 630 num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, 631 representing the number of batches to create in parallel. On one hand, 632 higher values can help mitigate the effect of stragglers. On the other 633 hand, higher values can increase contention if CPU is scarce. 634 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 635 whether the last batch should be dropped in case its size is smaller than 636 desired; the default behavior is not to drop the smaller batch. 637 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 638 representing the number of elements to process in parallel. If not 639 specified, `batch_size * num_parallel_batches` elements will be processed 640 in parallel. If the value `tf.data.experimental.AUTOTUNE` is used, then 641 the number of parallel calls is set dynamically based on available CPU. 642 643 Returns: 644 A `Dataset` transformation function, which can be passed to 645 `tf.data.Dataset.apply`. 646 647 Raises: 648 ValueError: If both `num_parallel_batches` and `num_parallel_calls` are 649 specified. 650 """ 651 652 if num_parallel_batches is None and num_parallel_calls is None: 653 num_parallel_calls = batch_size 654 elif num_parallel_batches is not None and num_parallel_calls is None: 655 num_parallel_calls = batch_size * num_parallel_batches 656 elif num_parallel_batches is not None and num_parallel_calls is not None: 657 raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " 658 "arguments are mutually exclusive.") 659 660 def _apply_fn(dataset): 661 return _MapAndBatchDataset(dataset, map_func, batch_size, 662 num_parallel_calls, drop_remainder, 663 use_legacy_function=True) 664 665 return _apply_fn 666 667 668@deprecation.deprecated( 669 None, 670 "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by " 671 "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data " 672 "optimizations will take care of using the fused implementation.") 673@tf_export("data.experimental.map_and_batch") 674def map_and_batch(map_func, 675 batch_size, 676 num_parallel_batches=None, 677 drop_remainder=False, 678 num_parallel_calls=None): 679 """Fused implementation of `map` and `batch`. 680 681 Maps `map_func` across `batch_size` consecutive elements of this dataset 682 and then combines them into a batch. Functionally, it is equivalent to `map` 683 followed by `batch`. However, by fusing the two transformations together, the 684 implementation can be more efficient. Surfacing this transformation in the API 685 is temporary. Once automatic input pipeline optimization is implemented, 686 the fusing of `map` and `batch` will happen automatically and this API will be 687 deprecated. 688 689 Args: 690 map_func: A function mapping a nested structure of tensors to another 691 nested structure of tensors. 692 batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 693 consecutive elements of this dataset to combine in a single batch. 694 num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, 695 representing the number of batches to create in parallel. On one hand, 696 higher values can help mitigate the effect of stragglers. On the other 697 hand, higher values can increase contention if CPU is scarce. 698 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 699 whether the last batch should be dropped in case its size is smaller than 700 desired; the default behavior is not to drop the smaller batch. 701 num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, 702 representing the number of elements to process in parallel. If not 703 specified, `batch_size * num_parallel_batches` elements will be processed 704 in parallel. If the value `tf.data.experimental.AUTOTUNE` is used, then 705 the number of parallel calls is set dynamically based on available CPU. 706 707 Returns: 708 A `Dataset` transformation function, which can be passed to 709 `tf.data.Dataset.apply`. 710 711 Raises: 712 ValueError: If both `num_parallel_batches` and `num_parallel_calls` are 713 specified. 714 """ 715 716 if num_parallel_batches is None and num_parallel_calls is None: 717 num_parallel_calls = batch_size 718 elif num_parallel_batches is not None and num_parallel_calls is None: 719 num_parallel_calls = batch_size * num_parallel_batches 720 elif num_parallel_batches is not None and num_parallel_calls is not None: 721 raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " 722 "arguments are mutually exclusive.") 723 724 def _apply_fn(dataset): 725 return _MapAndBatchDataset(dataset, map_func, batch_size, 726 num_parallel_calls, drop_remainder) 727 728 return _apply_fn 729 730 731class _RebatchDataset(dataset_ops.UnaryDataset): 732 """A `Dataset` that divides the batch size by `num_workers`.""" 733 734 def __init__(self, input_dataset, num_workers): 735 self._input_dataset = input_dataset 736 737 def recalculate_output_shapes(output_shapes): 738 """Recalculates the output_shapes after dividing it by num_workers.""" 739 if len(output_shapes) < 1: 740 raise ValueError("Input shape should have at least one dimension.") 741 if (tensor_shape.dimension_value(output_shapes[0]) and 742 tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0): 743 raise errors.InvalidArgumentError( 744 None, None, 745 "First dim of input shape: %d is not divisible by num_workers: %d" % 746 (output_shapes[0], num_workers)) 747 output_dims = [d for d in output_shapes.dims] 748 output_dims[0] = output_dims[0] // num_workers 749 return tensor_shape.TensorShape(output_dims) 750 751 input_types = dataset_ops.get_legacy_output_types(self._input_dataset) 752 input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset) 753 input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset) 754 output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes) 755 756 self._structure = structure.convert_legacy_structure( 757 input_types, output_shapes, input_classes) 758 variant_tensor = ged_ops.experimental_rebatch_dataset( 759 self._input_dataset._variant_tensor, # pylint: disable=protected-access 760 num_workers=num_workers, 761 **dataset_ops.flat_structure(self)) 762 super(_RebatchDataset, self).__init__(input_dataset, variant_tensor) 763 764 @property 765 def _element_structure(self): 766 return self._structure 767