# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Batching dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def batch_window(dataset):
  """Batches a window of tensors.

  Args:
    dataset: the input dataset.

  Returns:
    A `Tensor` representing the batch of the entire input dataset.
  """
  dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
  if isinstance(dataset_output_classes, tuple):
    raise TypeError("Input dataset expected to have a single component")
  if dataset_output_classes is ops.Tensor:
    return _batch_dense_window(dataset)
  elif dataset_output_classes is sparse_tensor.SparseTensor:
    return _batch_sparse_window(dataset)
  else:
    raise TypeError("Unsupported dataset type: %s" % dataset_output_classes)


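# NOTE: Illustrative sketch only (hypothetical names, not part of this module's
# API surface). `batch_window` stacks every element of a single-component
# window dataset into one tensor, assuming all elements share a shape:
#
#   window = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
#   batched = batch_window(window)
#   # `batched` evaluates to [[1, 2, 3], [4, 5, 6]], a tensor of shape [2, 3].

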
def _batch_dense_window(dataset):
  """Batches a window of dense tensors."""

  def key_fn(_):
    return np.int64(0)

  def shape_init_fn(_):
    return array_ops.shape(first_element)

  def shape_reduce_fn(state, value):
    check_ops.assert_equal(state, array_ops.shape(value))
    return state

  def finalize_fn(state):
    return state

  dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
  if dataset_output_shapes.is_fully_defined():
    shape = dataset_output_shapes
  else:
    first_element = get_single_element.get_single_element(dataset.take(1))
    shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
                                     finalize_fn)
    shape = get_single_element.get_single_element(
        dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))

  def batch_init_fn(_):
    batch_shape = array_ops.concat([[0], shape], 0)
    return gen_array_ops.empty(
        batch_shape, dtype=dataset_ops.get_legacy_output_types(dataset))

  def batch_reduce_fn(state, value):
    return array_ops.concat([state, [value]], 0)

  batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
  return get_single_element.get_single_element(
      dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))


def _batch_sparse_window(dataset):
  """Batches a window of sparse tensors."""

  def key_fn(_):
    return np.int64(0)

  def shape_init_fn(_):
    return first_element.dense_shape

  def shape_reduce_fn(state, value):
    check_ops.assert_equal(state, value.dense_shape)
    return state

  def finalize_fn(state):
    return state

  dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
  if dataset_output_shapes.is_fully_defined():
    shape = dataset_output_shapes
  else:
    first_element = get_single_element.get_single_element(dataset.take(1))
    shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
                                     finalize_fn)
    shape = get_single_element.get_single_element(
        dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))

  def batch_init_fn(_):
    indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
    return sparse_tensor.SparseTensor(
        indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
        values=constant_op.constant(
            [], shape=[0], dtype=dataset_ops.get_legacy_output_types(dataset)),
        dense_shape=array_ops.concat(
            [np.array([0], dtype=np.int64),
             math_ops.cast(shape, dtypes.int64)], 0))

  def batch_reduce_fn(state, value):
    return sparse_ops.sparse_concat(0, [state, value])

  def reshape_fn(value):
    return sparse_ops.sparse_reshape(
        value,
        array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))

  batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
  return get_single_element.get_single_element(
      dataset.map(reshape_fn).apply(
          grouping.group_by_reducer(key_fn, batch_reducer)))


@tf_export("data.experimental.dense_to_sparse_batch")
def dense_to_sparse_batch(batch_size, row_shape):
  """A transformation that batches ragged elements into `tf.SparseTensor`s.

  Like `Dataset.padded_batch()`, this transformation combines multiple
  consecutive elements of the dataset, which might have different
  shapes, into a single element. The resulting element has three
  components (`indices`, `values`, and `dense_shape`), which
  comprise a `tf.SparseTensor` that represents the same data. The
  `row_shape` represents the dense shape of each row in the
  resulting `tf.SparseTensor`, to which the effective batch size is
  prepended. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[6])) ==
  {
      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
       ['a', 'b', 'c', 'a', 'b'],                 # values
       [2, 6]),                                   # dense_shape
      ([[0, 0], [0, 1], [0, 2], [0, 3]],
       ['a', 'b', 'c', 'd'],
       [1, 6])
  }
  ```

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the
      number of consecutive elements of this dataset to combine in a
      single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
      object representing the equivalent dense shape of a row in the
      resulting `tf.SparseTensor`. Each element of this dataset must
      have the same rank as `row_shape`, and must have size less
      than or equal to `row_shape` in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)

  return _apply_fn


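# NOTE: Illustrative end-to-end sketch (hypothetical names and values; assumes
# a ragged generator is available). Like the other transformations in this
# module, `dense_to_sparse_batch` is applied via `Dataset.apply`:
#
#   dataset = tf.data.Dataset.from_generator(
#       lambda: iter([[1], [2, 3], [4, 5, 6]]),
#       output_types=tf.int64,
#       output_shapes=tf.TensorShape([None]))
#   dataset = dataset.apply(
#       tf.data.experimental.dense_to_sparse_batch(batch_size=2, row_shape=[6]))
#   # Each element is a `tf.SparseTensor` with dense shape `[2, 6]`; the final,
#   # smaller batch has dense shape `[1, 6]`.

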
def padded_batch_window(dataset, padded_shape, padding_value=None):
  """Batches a window of tensors with padding.

  Args:
    dataset: the input dataset.
    padded_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
      representing the shape to which the input elements should be padded prior
      to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
      `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
      maximum size of that dimension in each batch.
    padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
      padding value to use. Defaults are `0` for numeric types and the empty
      string for string types. If `dataset` contains `tf.SparseTensor`, this
      value is ignored.

  Returns:
    A `Tensor` representing the batch of the entire input dataset.

  Raises:
    ValueError: if invalid arguments are provided.
  """
  dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
  if not issubclass(dataset_output_classes,
                    (ops.Tensor, sparse_tensor.SparseTensor)):
    raise TypeError("Input dataset expected to have a single tensor component")
  if issubclass(dataset_output_classes, (ops.Tensor)):
    return _padded_batch_dense_window(dataset, padded_shape, padding_value)
  elif issubclass(dataset_output_classes, (sparse_tensor.SparseTensor)):
    if padding_value is not None:
      raise ValueError("Padding value not allowed for sparse tensors")
    return _padded_batch_sparse_window(dataset, padded_shape)
  else:
    raise TypeError("Unsupported dataset type: %s" % dataset_output_classes)


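# NOTE: Illustrative sketch only (hypothetical values). Given a window of dense
# tensors of varying length, `padded_batch_window` pads every element to the
# requested `padded_shape` (or to the per-batch maximum for `-1` dimensions)
# before stacking:
#
#   window = tf.data.Dataset.from_generator(
#       lambda: iter([[1], [2, 3]]),
#       output_types=tf.int64,
#       output_shapes=tf.TensorShape([None]))
#   padded = padded_batch_window(window, padded_shape=[2])
#   # `padded` evaluates to [[1, 0], [2, 3]], a tensor of shape [2, 2].

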
def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
  """Batches a window of dense tensors with padding."""

  padded_shape = math_ops.cast(
      convert.partial_shape_to_tensor(padded_shape), dtypes.int32)

  def key_fn(_):
    return np.int64(0)

  def max_init_fn(_):
    return padded_shape

  def max_reduce_fn(state, value):
    """Computes the maximum shape to pad to."""
    condition = math_ops.reduce_all(
        math_ops.logical_or(
            math_ops.less_equal(array_ops.shape(value), padded_shape),
            math_ops.equal(padded_shape, -1)))
    assert_op = control_flow_ops.Assert(condition, [
        "Actual shape greater than padded shape: ",
        array_ops.shape(value), padded_shape
    ])
    with ops.control_dependencies([assert_op]):
      return math_ops.maximum(state, array_ops.shape(value))

  def finalize_fn(state):
    return state

  # Compute the padded shape.
  max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
  padded_shape = get_single_element.get_single_element(
      dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))

  dataset_output_types = dataset_ops.get_legacy_output_types(dataset)
  if padding_value is None:
    if dataset_output_types == dtypes.string:
      padding_value = ""
    elif dataset_output_types == dtypes.bool:
      padding_value = False
    elif dataset_output_types == dtypes.variant:
      raise TypeError("Unable to create padding for field of type 'variant'")
    else:
      padding_value = 0

  def batch_init_fn(_):
    batch_shape = array_ops.concat(
        [np.array([0], dtype=np.int32), padded_shape], 0)
    return gen_array_ops.empty(batch_shape, dtype=dataset_output_types)

  def batch_reduce_fn(state, value):
    return array_ops.concat([state, [value]], 0)

  def pad_fn(value):
    shape = array_ops.shape(value)
    left = array_ops.zeros_like(shape)
    right = padded_shape - shape
    return array_ops.pad(
        value, array_ops.stack([left, right], 1), constant_values=padding_value)

  batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
  return get_single_element.get_single_element(
      dataset.map(pad_fn).apply(
          grouping.group_by_reducer(key_fn, batch_reducer)))


def _padded_batch_sparse_window(dataset, padded_shape):
  """Batches a window of sparse tensors with padding."""

  def key_fn(_):
    return np.int64(0)

  def max_init_fn(_):
    return convert.partial_shape_to_tensor(padded_shape)

  def max_reduce_fn(state, value):
    """Computes the maximum shape to pad to."""
    condition = math_ops.reduce_all(
        math_ops.logical_or(
            math_ops.less_equal(value.dense_shape, padded_shape),
            math_ops.equal(padded_shape, -1)))
    assert_op = control_flow_ops.Assert(condition, [
        "Actual shape greater than padded shape: ", value.dense_shape,
        padded_shape
    ])
    with ops.control_dependencies([assert_op]):
      return math_ops.maximum(state, value.dense_shape)

  def finalize_fn(state):
    return state

  # Compute the padded shape.
  max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
  padded_shape = get_single_element.get_single_element(
      dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))

  def batch_init_fn(_):
    indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
                                     0)
    return sparse_tensor.SparseTensor(
        indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
        values=constant_op.constant(
            [], shape=[0], dtype=dataset_ops.get_legacy_output_types(dataset)),
        dense_shape=array_ops.concat(
            [np.array([0], dtype=np.int64), padded_shape], 0))

  def batch_reduce_fn(state, value):
    padded_value = sparse_tensor.SparseTensor(
        indices=value.indices, values=value.values, dense_shape=padded_shape)
    reshaped_value = sparse_ops.sparse_reshape(
        padded_value,
        array_ops.concat(
            [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
    return sparse_ops.sparse_concat(0, [state, reshaped_value])

  reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
  return get_single_element.get_single_element(
      dataset.apply(grouping.group_by_reducer(key_fn, reducer)))


class _UnbatchDataset(dataset_ops.UnaryDataset):
  """A dataset that splits the elements of its input into multiple elements."""

  def __init__(self, input_dataset):
    """See `unbatch()` for more details."""
    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    flat_shapes = nest.flatten(input_shapes)
    if any(s.ndims == 0 for s in flat_shapes):
      raise ValueError("Cannot unbatch an input with scalar components.")
    known_batch_dim = tensor_shape.Dimension(None)
    for s in flat_shapes:
      try:
        known_batch_dim = known_batch_dim.merge_with(s[0])
      except ValueError:
        raise ValueError("Cannot unbatch an input whose components have "
                         "different batch sizes.")
    self._input_dataset = input_dataset

    self._structure = structure.convert_legacy_structure(
        dataset_ops.get_legacy_output_types(input_dataset),
        nest.map_structure(lambda s: s[1:], input_shapes),
        dataset_ops.get_legacy_output_classes(input_dataset))

    variant_tensor = ged_ops.experimental_unbatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **dataset_ops.flat_structure(self))
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure


@tf_export("data.experimental.unbatch")
def unbatch():
  """Splits elements of a dataset into multiple elements on the batch dimension.

  For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
  where `B` may vary for each input element, then for each element in the
  dataset, the unbatched dataset will contain `B` consecutive elements
  of shape `[a0, a1, ...]`.

  ```python
  # NOTE: The following example uses `{ ... }` to represent the contents
  # of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.unbatch()) == {
      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
    # are normalized to the rank-1 dense representation, so that the
    # sparse-oblivious unbatching logic will slice them
    # appropriately. This leads to a somewhat inefficient re-encoding step
    # for all SparseTensor components.
    # TODO(mrry): Consider optimizing this in future if it turns out to be
    # a bottleneck.
    def normalize(arg, *rest):
      # pylint: disable=protected-access
      if rest:
        return dataset._element_structure._to_batched_tensor_list((arg,) + rest)
      else:
        return dataset._element_structure._to_batched_tensor_list(arg)

    normalized_dataset = dataset.map(normalize)

    # NOTE(mrry): Our `map()` has lost information about the sparseness
    # of any SparseTensor components, so re-apply the structure of the
    # original dataset.
    restructured_dataset = _RestructuredDataset(
        normalized_dataset,
        dataset_ops.get_legacy_output_types(dataset),
        dataset_ops.get_legacy_output_shapes(dataset),
        dataset_ops.get_legacy_output_classes(dataset),
        allow_unsafe_cast=True)
    return _UnbatchDataset(restructured_dataset)

  return _apply_fn


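# NOTE: Illustrative sketch only (hypothetical pipeline; the equivalence below
# is stated as an assumption about dense elements, not a guarantee of this
# implementation). For dense batches, `unbatch()` behaves like slicing each
# element along its first dimension:
#
#   batched = tf.data.Dataset.from_tensors([[1, 2], [3, 4], [5, 6]])
#   flat = batched.apply(tf.data.experimental.unbatch())
#   # yields [1, 2], [3, 4], [5, 6] as three separate elements, similar to
#   # batched.flat_map(tf.data.Dataset.from_tensor_slices).

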
class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""

  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details."""
    if not isinstance(
        dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType):
      raise TypeError("DenseToSparseDataset requires an input whose elements "
                      "have a single component, whereas the input has %r." %
                      dataset_ops.get_legacy_output_types(input_dataset))
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    self._structure = structure.SparseTensorStructure(
        dataset_ops.get_legacy_output_types(input_dataset),
        tensor_shape.vector(None).concatenate(self._row_shape))

    variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **dataset_ops.flat_structure(self))
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)

  @property
  def _element_structure(self):
    return self._structure


class _RestructuredDataset(dataset_ops.UnaryDataset):
  """An internal helper for changing the structure and shape of a dataset."""

  def __init__(self,
               dataset,
               output_types,
               output_shapes=None,
               output_classes=None,
               allow_unsafe_cast=False):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types.
        If omitted, the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
        reported output types and shapes of the restructured dataset, e.g. to
        switch a sparse tensor represented as `tf.variant` to its user-visible
        type and shape.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    self._input_dataset = dataset

    input_types = dataset_ops.get_legacy_output_types(dataset)
    if not allow_unsafe_cast:
      # Validate that the types are compatible.
      output_types = nest.map_structure(dtypes.as_dtype, output_types)
      flat_original_types = nest.flatten(input_types)
      flat_new_types = nest.flatten(output_types)
      if flat_original_types != flat_new_types:
        raise ValueError(
            "Dataset with output types %r cannot be restructured to have "
            "output types %r" %
            (dataset_ops.get_legacy_output_types(dataset), output_types))

    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    if output_shapes is None:
      # Inherit shapes from the original `dataset`.
      output_shapes = nest.pack_sequence_as(
          output_types, nest.flatten(input_shapes))
    else:
      if not allow_unsafe_cast:
        # Validate that the shapes are compatible.
        nest.assert_same_structure(output_types, output_shapes)
        flat_original_shapes = nest.flatten(input_shapes)
        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

        for original_shape, new_shape in zip(flat_original_shapes,
                                             flat_new_shapes):
          if not original_shape.is_compatible_with(new_shape):
            raise ValueError(
                "Dataset with output shapes %r cannot be restructured to have "
                "incompatible output shapes %r" % (input_shapes,
                                                   output_shapes))
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

    input_classes = dataset_ops.get_legacy_output_classes(dataset)
    if output_classes is None:
      # Inherit class types from the original `dataset`.
      output_classes = nest.pack_sequence_as(
          output_types, nest.flatten(input_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure


class _MapAndBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over a batch of elements."""

  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
    if constant_drop_remainder:
      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
      # or `False` (explicitly retaining the remainder).
      self._structure = self._map_func.output_structure._batch(  # pylint: disable=protected-access
          tensor_util.constant_value(self._batch_size_t))
    else:
      self._structure = self._map_func.output_structure._batch(None)  # pylint: disable=protected-access
    variant_tensor = ged_ops.experimental_map_and_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **dataset_ops.flat_structure(self))
    super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure


@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()`")
@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"])
def map_and_batch_with_legacy_function(map_func,
                                       batch_size,
                                       num_parallel_batches=None,
                                       drop_remainder=False,
                                       num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  NOTE: This is an escape hatch for existing uses of `map_and_batch` that do not
  work with V2 functions. New uses are strongly discouraged and existing uses
  should migrate to `map_and_batch` as this method will be removed in V2.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel. If the value `tf.data.experimental.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
                     "arguments are mutually exclusive.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder,
                               use_legacy_function=True)

  return _apply_fn


@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
    "optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
                  batch_size,
                  num_parallel_batches=None,
                  drop_remainder=False,
                  num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  Maps `map_func` across `batch_size` consecutive elements of this dataset
  and then combines them into a batch. Functionally, it is equivalent to `map`
  followed by `batch`. However, by fusing the two transformations together, the
  implementation can be more efficient. Surfacing this transformation in the API
  is temporary. Once automatic input pipeline optimization is implemented,
  the fusing of `map` and `batch` will happen automatically and this API will be
  deprecated.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel. If the value `tf.data.experimental.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_batches is None and num_parallel_calls is None:
    num_parallel_calls = batch_size
  elif num_parallel_batches is not None and num_parallel_calls is None:
    num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
                     "arguments are mutually exclusive.")

  def _apply_fn(dataset):
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder)

  return _apply_fn


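# NOTE: Illustrative usage sketch (hypothetical pipeline; not part of this
# module's API surface). `map_and_batch` is applied via `Dataset.apply` and is
# functionally equivalent to a `map` followed by a `batch`:
#
#   dataset = tf.data.Dataset.range(100)
#   dataset = dataset.apply(
#       tf.data.experimental.map_and_batch(
#           map_func=lambda x: x * 2, batch_size=8, drop_remainder=True))
#   # Each element is now a batch of 8 mapped values, equivalent to
#   # dataset.map(lambda x: x * 2).batch(8, drop_remainder=True).

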
class _RebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that divides the batch size by `num_workers`."""

  def __init__(self, input_dataset, num_workers):
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing them by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError("Input shape should have at least one dimension.")
      if (tensor_shape.dimension_value(output_shapes[0]) and
          tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
        raise errors.InvalidArgumentError(
            None, None,
            "First dim of input shape: %d is not divisible by num_workers: %d" %
            (output_shapes[0], num_workers))
      output_dims = [d for d in output_shapes.dims]
      output_dims[0] = output_dims[0] // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure